[Databricks] Fix single notebook task launch without a task group #162

Merged · 8 commits · Feb 28, 2023
3 changes: 3 additions & 0 deletions cosmos/providers/databricks/constants.py
@@ -0,0 +1,3 @@
import os

JOBS_API_VERSION = os.getenv("JOBS_API_VERSION", "2.1")
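
Note: the new constant makes the Jobs API version environment-driven, defaulting to 2.1, the version that supports multi-task runs. A minimal sketch of how the override is expected to behave; the 2.0 value mentioned below is purely illustrative, not something this PR configures:

import os

# Unset, the operator talks to the Jobs API 2.1; exporting JOBS_API_VERSION=2.0
# before starting Airflow would instead pin every runs_api.get_run(...) call
# in notebook.py to the older API.
JOBS_API_VERSION = os.getenv("JOBS_API_VERSION", "2.1")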
26 changes: 20 additions & 6 deletions cosmos/providers/databricks/notebook.py
@@ -11,6 +11,8 @@
from databricks_cli.runs.api import RunsApi
from databricks_cli.sdk.api_client import ApiClient

from cosmos.providers.databricks.constants import JOBS_API_VERSION


class DatabricksNotebookOperator(BaseOperator):
"""
@@ -98,9 +100,9 @@ def convert_to_databricks_workflow_task(
if hasattr(self.task_group, "notebook_packages"):
self.notebook_packages.extend(self.task_group.notebook_packages)
result = {
"task_key": self.dag_id + "__" + self.task_id.replace(".", "__"),
"task_key": self._get_databricks_task_id(self.task_id),
"depends_on": [
{"task_key": self.dag_id + "__" + t.replace(".", "__")}
{"task_key": self._get_databricks_task_id(t)}
for t in self.upstream_task_ids
if t in relevant_upstreams
],
@@ -116,6 +118,10 @@
}
return result

def _get_databricks_task_id(self, task_id: str):
"""Get the databricks task ID using dag_id and task_id. removes illegal characters."""
return self.dag_id + "__" + task_id.replace(".", "__")

def monitor_databricks_job(self):
"""Monitor the Databricks job until it completes. Raises Airflow exception if the job fails."""
api_client = self._get_api_client()
@@ -124,13 +130,18 @@ def monitor_databricks_job(self):
self._wait_for_pending_task(current_task, runs_api)
self._wait_for_running_task(current_task, runs_api)
self._wait_for_terminating_task(current_task, runs_api)
final_state = runs_api.get_run(current_task["run_id"])["state"]
final_state = runs_api.get_run(
current_task["run_id"], version=JOBS_API_VERSION
)["state"]
self._handle_final_state(final_state)

def _get_current_databricks_task(self, runs_api):
return {
x["task_key"]: x for x in runs_api.get_run(self.databricks_run_id)["tasks"]
}[self.dag_id + "__" + self.task_id.replace(".", "__")]
x["task_key"]: x
for x in runs_api.get_run(self.databricks_run_id, version=JOBS_API_VERSION)[
"tasks"
]
}[self._get_databricks_task_id(self.task_id)]

def _handle_final_state(self, final_state):
if final_state.get("life_cycle_state", None) != "TERMINATED":
@@ -145,7 +156,9 @@ def _handle_final_state(self, final_state):
)

def _get_lifestyle_state(self, current_task, runs_api):
return runs_api.get_run(current_task["run_id"])["state"]["life_cycle_state"]
return runs_api.get_run(current_task["run_id"], version=JOBS_API_VERSION)[
"state"
]["life_cycle_state"]

def _wait_on_state(self, current_task, runs_api, state):
while self._get_lifestyle_state(current_task, runs_api) == state:
@@ -174,6 +187,7 @@ def launch_notebook_job(self):
"""Launch the notebook as a one-time job to Databricks."""
api_client = self._get_api_client()
run_json = {
"run_name": self._get_databricks_task_id(self.task_id),
"notebook_task": {
"notebook_path": self.notebook_path,
"base_parameters": {"source": self.source},
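
Taken together, the notebook.py changes give a notebook launched outside a task group a deterministic run name and let the monitoring code look its task back up by the same key. A rough sketch of the naming rule; the IDs below are hypothetical, only the transformation mirrors _get_databricks_task_id:

dag_id = "example_databricks_notebook"   # hypothetical
task_id = "workflow.notebook_1"          # hypothetical, contains a dot
task_key = dag_id + "__" + task_id.replace(".", "__")
print(task_key)  # example_databricks_notebook__workflow__notebook_1

# launch_notebook_job() now also submits this value as "run_name", so the
# one-time run is identifiable in the Databricks UI and in get_run() output.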
2 changes: 1 addition & 1 deletion docs/databricks/workflows.rst
@@ -12,7 +12,7 @@ The DatabricksWorkflowTaskGroup is designed to look and function like a standard
with the added ability to include specific Databricks arguments.
An example of how to use the DatabricksWorkflowTaskGroup can be seen in the following code snippet:

.. exampleinclude:: /../astronomer/providers/databricks/example_dags/example_databricks_workflow.py
.. literalinclude:: /../examples/databricks/example_databricks_workflow.py
:language: python
:dedent: 4
:start-after: [START howto_databricks_workflow_notebook]
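
The corrected directive pulls the snippet that sits between marker comments in the relocated example file. A sketch of what such markers look like inside the included DAG; only the start marker name comes from the directive above, the end marker and elided code are assumptions:

# [START howto_databricks_workflow_notebook]
#   ... task group and notebook operator definitions rendered in the docs ...
# [END howto_databricks_workflow_notebook]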
61 changes: 61 additions & 0 deletions examples/databricks/example_databricks_notebook.py
@@ -0,0 +1,61 @@
"""Example DAG for using the DatabricksNotebookOperator."""
import os
from datetime import timedelta

from airflow.models.dag import DAG
from airflow.utils.timezone import datetime

from cosmos.providers.databricks.notebook import DatabricksNotebookOperator

EXECUTION_TIMEOUT = int(os.getenv("EXECUTION_TIMEOUT", 6))
default_args = {
"execution_timeout": timedelta(hours=EXECUTION_TIMEOUT),
"retries": int(os.getenv("DEFAULT_TASK_RETRIES", 2)),
"retry_delay": timedelta(seconds=int(os.getenv("DEFAULT_RETRY_DELAY_SECONDS", 60))),
}

DATABRICKS_CONN_ID = os.getenv("ASTRO_DATABRICKS_CONN_ID", "databricks_conn")
NEW_CLUSTER_SPEC = {
"cluster_name": "",
"spark_version": "11.3.x-scala2.12",
"aws_attributes": {
"first_on_demand": 1,
"availability": "SPOT_WITH_FALLBACK",
"zone_id": "us-east-2b",
"spot_bid_price_percent": 100,
"ebs_volume_count": 0,
},
"node_type_id": "i3.xlarge",
"spark_env_vars": {"PYSPARK_PYTHON": "/databricks/python3/bin/python3"},
"enable_elastic_disk": False,
"data_security_mode": "LEGACY_SINGLE_USER_STANDARD",
"runtime_engine": "STANDARD",
"num_workers": 8,
}

dag = DAG(
dag_id="example_databricks_notebook",
start_date=datetime(2022, 1, 1),
schedule_interval=None,
catchup=False,
default_args=default_args,
tags=["example", "async", "databricks"],
)
with dag:
notebook_1 = DatabricksNotebookOperator(
task_id="notebook_1",
databricks_conn_id=DATABRICKS_CONN_ID,
notebook_path="/Shared/Notebook_1",
notebook_packages=[
{
"pypi": {
"package": "simplejson==3.18.0",
"repo": "https://pypi.org/simple",
}
},
{"pypi": {"package": "Faker"}},
],
source="WORKSPACE",
job_cluster_key="random_cluster_key",
new_cluster=NEW_CLUSTER_SPEC,
)
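
The example DAG can be exercised without a scheduler. A minimal sketch, assuming a valid databricks_conn Airflow connection and an existing /Shared/Notebook_1 workspace path, neither of which this file creates:

# Appending this to the example lets it run in-process via
# `python example_databricks_notebook.py`; dag.test() is the same
# entry point the unit tests below use.
if __name__ == "__main__":
    dag.test()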
37 changes: 26 additions & 11 deletions tests/databricks/test_notebook.py
@@ -108,13 +108,17 @@ def test_databricks_notebook_operator_with_taskgroup(
@mock.patch(
"cosmos.providers.databricks.notebook.DatabricksNotebookOperator.monitor_databricks_job"
)
@mock.patch(
"cosmos.providers.databricks.notebook.DatabricksNotebookOperator._get_databricks_task_id"
)
@mock.patch(
"cosmos.providers.databricks.notebook.DatabricksNotebookOperator._get_api_client"
)
@mock.patch("cosmos.providers.databricks.notebook.RunsApi")
def test_databricks_notebook_operator_without_taskgroup_new_cluster(
mock_runs_api, mock_api_client, mock_monitor, dag
mock_runs_api, mock_api_client, mock_get_databricks_task_id, mock_monitor, dag
):
mock_get_databricks_task_id.return_value = "1234"
mock_runs_api.return_value = mock.MagicMock()
with dag:
DatabricksNotebookOperator(
@@ -132,6 +136,7 @@ def test_databricks_notebook_operator_without_taskgroup_new_cluster(
dag.test()
mock_runs_api.return_value.submit_run.assert_called_once_with(
{
"run_name": "1234",
"notebook_task": {
"notebook_path": "/foo/bar",
"base_parameters": {"source": "WORKSPACE"},
@@ -146,13 +151,17 @@ def test_databricks_notebook_operator_without_taskgroup_new_cluster(
@mock.patch(
"cosmos.providers.databricks.notebook.DatabricksNotebookOperator.monitor_databricks_job"
)
@mock.patch(
"cosmos.providers.databricks.notebook.DatabricksNotebookOperator._get_databricks_task_id"
)
@mock.patch(
"cosmos.providers.databricks.notebook.DatabricksNotebookOperator._get_api_client"
)
@mock.patch("cosmos.providers.databricks.notebook.RunsApi")
def test_databricks_notebook_operator_without_taskgroup_existing_cluster(
mock_runs_api, mock_api_client, mock_monitor, dag
mock_runs_api, mock_api_client, mock_get_databricks_task_id, mock_monitor, dag
):
mock_get_databricks_task_id.return_value = "1234"
mock_runs_api.return_value = mock.MagicMock()
with dag:
DatabricksNotebookOperator(
@@ -170,6 +179,7 @@ def test_databricks_notebook_operator_without_taskgroup_existing_cluster(
dag.test()
mock_runs_api.return_value.submit_run.assert_called_once_with(
{
"run_name": "1234",
"notebook_task": {
"notebook_path": "/foo/bar",
"base_parameters": {"source": "WORKSPACE"},
@@ -273,7 +283,7 @@ def test_wait_for_pending_task(mock_sleep, mock_runs_api, databricks_notebook_op
{"state": {"life_cycle_state": "RUNNING"}},
]
databricks_notebook_operator._wait_for_pending_task(current_task, mock_runs_api)
mock_runs_api.get_run.assert_called_with("123")
mock_runs_api.get_run.assert_called_with("123", version="2.1")
assert mock_runs_api.get_run.call_count == 2
mock_runs_api.reset_mock()

@@ -290,7 +300,7 @@ def test_wait_for_terminating_task(
{"state": {"life_cycle_state": "TERMINATED"}},
]
databricks_notebook_operator._wait_for_terminating_task(current_task, mock_runs_api)
mock_runs_api.get_run.assert_called_with("123")
mock_runs_api.get_run.assert_called_with("123", version="2.1")
assert mock_runs_api.get_run.call_count == 3
mock_runs_api.reset_mock()

@@ -305,7 +315,7 @@ def test_wait_for_running_task(mock_sleep, mock_runs_api, databricks_notebook_op
{"state": {"life_cycle_state": "TERMINATED"}},
]
databricks_notebook_operator._wait_for_running_task(current_task, mock_runs_api)
mock_runs_api.get_run.assert_called_with("123")
mock_runs_api.get_run.assert_called_with("123", version="2.1")
assert mock_runs_api.get_run.call_count == 3
mock_runs_api.reset_mock()

@@ -328,15 +338,16 @@ def test_get_lifestyle_state(databricks_notebook_operator):
"cosmos.providers.databricks.notebook.DatabricksNotebookOperator._get_api_client"
)
@mock.patch(
"cosmos.providers.databricks.notebook.DatabricksNotebookOperator._get_current_databricks_task"
"cosmos.providers.databricks.notebook.DatabricksNotebookOperator._get_databricks_task_id"
)
def test_monitor_databricks_job_success(
mock_get_task_name,
mock_get_databricks_task_id,
mock_get_api_client,
mock_runs_api,
mock_databricks_hook,
databricks_notebook_operator,
):
mock_get_databricks_task_id.return_value = "1"
# Define the expected response
response = {
"run_page_url": "https://databricks-instance-xyz.cloud.databricks.com/#job/1234/run/1",
@@ -354,8 +365,11 @@
}
mock_runs_api.return_value.get_run.return_value = response

databricks_notebook_operator.databricks_run_id = "1234"
databricks_notebook_operator.databricks_run_id = "1"
databricks_notebook_operator.monitor_databricks_job()
mock_runs_api.return_value.get_run.assert_called_with(
databricks_notebook_operator.databricks_run_id, version="2.1"
)


@mock.patch("cosmos.providers.databricks.notebook.DatabricksHook")
@@ -364,15 +378,16 @@ def test_monitor_databricks_job_success(
"cosmos.providers.databricks.notebook.DatabricksNotebookOperator._get_api_client"
)
@mock.patch(
"cosmos.providers.databricks.notebook.DatabricksNotebookOperator._get_current_databricks_task"
"cosmos.providers.databricks.notebook.DatabricksNotebookOperator._get_databricks_task_id"
)
def test_monitor_databricks_job_fail(
mock_get_task_name,
mock_get_databricks_task_id,
mock_get_api_client,
mock_runs_api,
mock_databricks_hook,
databricks_notebook_operator,
):
mock_get_databricks_task_id.return_value = "1"
# Define the expected response
response = {
"run_page_url": "https://databricks-instance-xyz.cloud.databricks.com/#job/1234/run/1",
Expand All @@ -390,6 +405,6 @@ def test_monitor_databricks_job_fail(
}
mock_runs_api.return_value.get_run.return_value = response

databricks_notebook_operator.databricks_run_id = "1234"
databricks_notebook_operator.databricks_run_id = "1"
with pytest.raises(AirflowException):
databricks_notebook_operator.monitor_databricks_job()
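
For context on the version="2.1" assertions above: the test environment leaves JOBS_API_VERSION unset, so the module-level constant falls back to its default, which is what every mocked get_run call is expected to receive. A one-line illustration:

import os
assert os.getenv("JOBS_API_VERSION", "2.1") == "2.1"  # holds whenever the env var is unset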