From a8870337c79d327d419718bd7706549df21fb1ee Mon Sep 17 00:00:00 2001 From: Pankaj Koti <pankajkoti699@gmail.com> Date: Sat, 25 Feb 2023 20:27:41 +0530 Subject: [PATCH 1/8] Print run --- cosmos/providers/databricks/notebook.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cosmos/providers/databricks/notebook.py b/cosmos/providers/databricks/notebook.py index a299f7ede..f478314e4 100644 --- a/cosmos/providers/databricks/notebook.py +++ b/cosmos/providers/databricks/notebook.py @@ -193,6 +193,7 @@ def launch_notebook_job(self): runs_api = RunsApi(api_client) run = runs_api.submit_run(run_json) self.databricks_run_id = run["run_id"] + print(run) return run def execute(self, context: Context) -> Any: From 77417a39e16e7f2431cd09534b1ee4e98e6daf01 Mon Sep 17 00:00:00 2001 From: Pankaj Koti <pankajkoti699@gmail.com> Date: Sat, 25 Feb 2023 20:47:25 +0530 Subject: [PATCH 2/8] Additional prints --- cosmos/providers/databricks/notebook.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cosmos/providers/databricks/notebook.py b/cosmos/providers/databricks/notebook.py index f478314e4..8ba03768b 100644 --- a/cosmos/providers/databricks/notebook.py +++ b/cosmos/providers/databricks/notebook.py @@ -128,6 +128,7 @@ def monitor_databricks_job(self): self._handle_final_state(final_state) def _get_current_databricks_task(self, runs_api): + print(self.databricks_run_id, runs_api.get_run(self.databricks_run_id)) return { x["task_key"]: x for x in runs_api.get_run(self.databricks_run_id)["tasks"] }[self.dag_id + "__" + self.task_id.replace(".", "__")] @@ -193,7 +194,7 @@ def launch_notebook_job(self): runs_api = RunsApi(api_client) run = runs_api.submit_run(run_json) self.databricks_run_id = run["run_id"] - print(run) + print(run, self.databricks_run_id) return run def execute(self, context: Context) -> Any: From f3c69b5deef8dd78cca57108b94d32d9ef052ea8 Mon Sep 17 00:00:00 2001 From: Pankaj Koti <pankajkoti699@gmail.com> Date: Sat, 25 Feb 2023 21:05:26 +0530 Subject: [PATCH 3/8] Pass API version to API call --- cosmos/providers/databricks/notebook.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/cosmos/providers/databricks/notebook.py b/cosmos/providers/databricks/notebook.py index 8ba03768b..3312f5a2f 100644 --- a/cosmos/providers/databricks/notebook.py +++ b/cosmos/providers/databricks/notebook.py @@ -128,9 +128,13 @@ def monitor_databricks_job(self): self._handle_final_state(final_state) def _get_current_databricks_task(self, runs_api): - print(self.databricks_run_id, runs_api.get_run(self.databricks_run_id)) + print( + self.databricks_run_id, + runs_api.get_run(self.databricks_run_id, version="2.1"), + ) return { - x["task_key"]: x for x in runs_api.get_run(self.databricks_run_id)["tasks"] + x["task_key"]: x + for x in runs_api.get_run(self.databricks_run_id, version="2.1")["tasks"] }[self.dag_id + "__" + self.task_id.replace(".", "__")] def _handle_final_state(self, final_state): From fe36ff2bb1ce44adcedaa2c32b7fefff74564217 Mon Sep 17 00:00:00 2001 From: Pankaj Koti <pankajkoti699@gmail.com> Date: Sat, 25 Feb 2023 21:27:46 +0530 Subject: [PATCH 4/8] Pass run-name so that it maps to task_key --- cosmos/providers/databricks/notebook.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cosmos/providers/databricks/notebook.py b/cosmos/providers/databricks/notebook.py index 3312f5a2f..172bef5c5 100644 --- a/cosmos/providers/databricks/notebook.py +++ b/cosmos/providers/databricks/notebook.py @@ -179,6 +179,7 @@ def launch_notebook_job(self): """Launch the 
notebook as a one-time job to Databricks."""
         api_client = self._get_api_client()
         run_json = {
+            "run_name": self.dag_id + "__" + self.task_id.replace(".", "__"),
             "notebook_task": {
                 "notebook_path": self.notebook_path,
                 "base_parameters": {"source": self.source},

From 862694cf9a68ca9e1057646025fad0200d6ac6af Mon Sep 17 00:00:00 2001
From: Pankaj Koti <pankajkoti699@gmail.com>
Date: Sat, 25 Feb 2023 22:14:44 +0530
Subject: [PATCH 5/8] Add example DAG for single notebook launch

---
 cosmos/providers/databricks/constants.py      |  3 +
 cosmos/providers/databricks/notebook.py       | 11 ++--
 .../databricks/example_databricks_notebook.py | 61 +++++++++++++++++++
 3 files changed, 69 insertions(+), 6 deletions(-)
 create mode 100644 cosmos/providers/databricks/constants.py
 create mode 100644 examples/databricks/example_databricks_notebook.py

diff --git a/cosmos/providers/databricks/constants.py b/cosmos/providers/databricks/constants.py
new file mode 100644
index 000000000..13d81c9b9
--- /dev/null
+++ b/cosmos/providers/databricks/constants.py
@@ -0,0 +1,3 @@
+import os
+
+JOBS_API_VERSION = os.getenv("JOBS_API_VERSION", "2.1")
diff --git a/cosmos/providers/databricks/notebook.py b/cosmos/providers/databricks/notebook.py
index 172bef5c5..f5837fac8 100644
--- a/cosmos/providers/databricks/notebook.py
+++ b/cosmos/providers/databricks/notebook.py
@@ -11,6 +11,8 @@
 from databricks_cli.runs.api import RunsApi
 from databricks_cli.sdk.api_client import ApiClient
 
+from cosmos.providers.databricks.constants import JOBS_API_VERSION
+
 
 class DatabricksNotebookOperator(BaseOperator):
     """
@@ -128,13 +130,11 @@ def monitor_databricks_job(self):
         self._handle_final_state(final_state)
 
     def _get_current_databricks_task(self, runs_api):
-        print(
-            self.databricks_run_id,
-            runs_api.get_run(self.databricks_run_id, version="2.1"),
-        )
         return {
             x["task_key"]: x
-            for x in runs_api.get_run(self.databricks_run_id, version="2.1")["tasks"]
+            for x in runs_api.get_run(self.databricks_run_id, version=JOBS_API_VERSION)[
+                "tasks"
+            ]
         }[self.dag_id + "__" + self.task_id.replace(".", "__")]
 
     def _handle_final_state(self, final_state):
@@ -199,7 +199,6 @@ def launch_notebook_job(self):
         runs_api = RunsApi(api_client)
         run = runs_api.submit_run(run_json)
         self.databricks_run_id = run["run_id"]
-        print(run, self.databricks_run_id)
         return run
 
     def execute(self, context: Context) -> Any:
diff --git a/examples/databricks/example_databricks_notebook.py b/examples/databricks/example_databricks_notebook.py
new file mode 100644
index 000000000..beb96f5a0
--- /dev/null
+++ b/examples/databricks/example_databricks_notebook.py
@@ -0,0 +1,61 @@
+"""Example DAG for using the DatabricksNotebookOperator."""
+import os
+from datetime import timedelta
+
+from airflow.models.dag import DAG
+from airflow.utils.timezone import datetime
+
+from cosmos.providers.databricks.notebook import DatabricksNotebookOperator
+
+EXECUTION_TIMEOUT = int(os.getenv("EXECUTION_TIMEOUT", 6))
+default_args = {
+    "execution_timeout": timedelta(hours=EXECUTION_TIMEOUT),
+    "retries": int(os.getenv("DEFAULT_TASK_RETRIES", 2)),
+    "retry_delay": timedelta(seconds=int(os.getenv("DEFAULT_RETRY_DELAY_SECONDS", 60))),
+}
+
+DATABRICKS_CONN_ID = os.getenv("ASTRO_DATABRICKS_CONN_ID", "databricks_conn")
+NEW_CLUSTER_SPEC = {
+    "cluster_name": "",
+    "spark_version": "11.3.x-scala2.12",
+    "aws_attributes": {
+        "first_on_demand": 1,
+        "availability": "SPOT_WITH_FALLBACK",
+        "zone_id": "us-east-2b",
+        "spot_bid_price_percent": 100,
+        "ebs_volume_count": 0,
+    },
+    "node_type_id": "i3.xlarge",
+    "spark_env_vars": {"PYSPARK_PYTHON": "/databricks/python3/bin/python3"},
+    "enable_elastic_disk": False,
+    "data_security_mode": "LEGACY_SINGLE_USER_STANDARD",
+    "runtime_engine": "STANDARD",
+    "num_workers": 8,
+}
+
+dag = DAG(
+    dag_id="example_databricks_notebook",
+    start_date=datetime(2022, 1, 1),
+    schedule_interval=None,
+    catchup=False,
+    default_args=default_args,
+    tags=["example", "async", "databricks"],
+)
+with dag:
+    notebook_1 = DatabricksNotebookOperator(
+        task_id="notebook_1",
+        databricks_conn_id=DATABRICKS_CONN_ID,
+        notebook_path="/Shared/Notebook_1",
+        notebook_packages=[
+            {
+                "pypi": {
+                    "package": "simplejson==3.18.0",
+                    "repo": "https://pypi.org/simple",
+                }
+            },
+            {"pypi": {"package": "Faker"}},
+        ],
+        source="WORKSPACE",
+        job_cluster_key="random_cluster_key",
+        new_cluster=NEW_CLUSTER_SPEC,
+    )

From af3f3c11de946aea50252ce400931a5d18c6e4a4 Mon Sep 17 00:00:00 2001
From: Pankaj Koti <pankajkoti699@gmail.com>
Date: Mon, 27 Feb 2023 22:12:25 +0530
Subject: [PATCH 6/8] Fix Databricks example DAG path in docs

---
 docs/databricks/workflows.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/databricks/workflows.rst b/docs/databricks/workflows.rst
index f2140fbc1..94717ad9b 100644
--- a/docs/databricks/workflows.rst
+++ b/docs/databricks/workflows.rst
@@ -12,7 +12,7 @@ The DatabricksWorkflowTaskGroup is designed to look and function like a standard
 with the added ability to include specific Databricks arguments. An example of how to use
 the DatabricksWorkflowTaskGroup can be seen in the following code snippet:
 
-.. exampleinclude:: /../astronomer/providers/databricks/example_dags/example_databricks_workflow.py
+.. literalinclude:: /../examples/databricks/example_databricks_workflow.py
    :language: python
    :dedent: 4
    :start-after: [START howto_databricks_workflow_notebook]

From efda95afad4c8f5ca24654e9e225948bc940ef87 Mon Sep 17 00:00:00 2001
From: Pankaj Koti <pankajkoti699@gmail.com>
Date: Tue, 28 Feb 2023 15:00:08 +0530
Subject: [PATCH 7/8] Use latest version across all get_run calls and add tests

---
 cosmos/providers/databricks/notebook.py | 18 ++++++++----
 tests/databricks/test_notebook.py       | 37 +++++++++++++++++--------
 2 files changed, 39 insertions(+), 16 deletions(-)

diff --git a/cosmos/providers/databricks/notebook.py b/cosmos/providers/databricks/notebook.py
index f5837fac8..8035f7993 100644
--- a/cosmos/providers/databricks/notebook.py
+++ b/cosmos/providers/databricks/notebook.py
@@ -100,7 +100,7 @@ def convert_to_databricks_workflow_task(
         if hasattr(self.task_group, "notebook_packages"):
             self.notebook_packages.extend(self.task_group.notebook_packages)
         result = {
-            "task_key": self.dag_id + "__" + self.task_id.replace(".", "__"),
+            "task_key": self._get_databricks_task_id(),
             "depends_on": [
                 {"task_key": self.dag_id + "__" + t.replace(".", "__")}
                 for t in self.upstream_task_ids
@@ -118,6 +118,10 @@ def convert_to_databricks_workflow_task(
         }
         return result
 
+    def _get_databricks_task_id(self):
+        """Get the Databricks task ID using dag_id and task_id; removes illegal characters."""
+        return self.dag_id + "__" + self.task_id.replace(".", "__")
+
     def monitor_databricks_job(self):
         """Monitor the Databricks job until it completes. 
Raises Airflow exception if the job fails.""" api_client = self._get_api_client() @@ -126,7 +130,9 @@ def monitor_databricks_job(self): self._wait_for_pending_task(current_task, runs_api) self._wait_for_running_task(current_task, runs_api) self._wait_for_terminating_task(current_task, runs_api) - final_state = runs_api.get_run(current_task["run_id"])["state"] + final_state = runs_api.get_run( + current_task["run_id"], version=JOBS_API_VERSION + )["state"] self._handle_final_state(final_state) def _get_current_databricks_task(self, runs_api): @@ -135,7 +141,7 @@ def _get_current_databricks_task(self, runs_api): for x in runs_api.get_run(self.databricks_run_id, version=JOBS_API_VERSION)[ "tasks" ] - }[self.dag_id + "__" + self.task_id.replace(".", "__")] + }[self._get_databricks_task_id()] def _handle_final_state(self, final_state): if final_state.get("life_cycle_state", None) != "TERMINATED": @@ -150,7 +156,9 @@ def _handle_final_state(self, final_state): ) def _get_lifestyle_state(self, current_task, runs_api): - return runs_api.get_run(current_task["run_id"])["state"]["life_cycle_state"] + return runs_api.get_run(current_task["run_id"], version=JOBS_API_VERSION)[ + "state" + ]["life_cycle_state"] def _wait_on_state(self, current_task, runs_api, state): while self._get_lifestyle_state(current_task, runs_api) == state: @@ -179,7 +187,7 @@ def launch_notebook_job(self): """Launch the notebook as a one-time job to Databricks.""" api_client = self._get_api_client() run_json = { - "run_name": self.dag_id + "__" + self.task_id.replace(".", "__"), + "run_name": self._get_databricks_task_id(), "notebook_task": { "notebook_path": self.notebook_path, "base_parameters": {"source": self.source}, diff --git a/tests/databricks/test_notebook.py b/tests/databricks/test_notebook.py index f5403b06b..1c4044241 100644 --- a/tests/databricks/test_notebook.py +++ b/tests/databricks/test_notebook.py @@ -108,13 +108,17 @@ def test_databricks_notebook_operator_with_taskgroup( @mock.patch( "cosmos.providers.databricks.notebook.DatabricksNotebookOperator.monitor_databricks_job" ) +@mock.patch( + "cosmos.providers.databricks.notebook.DatabricksNotebookOperator._get_databricks_task_id" +) @mock.patch( "cosmos.providers.databricks.notebook.DatabricksNotebookOperator._get_api_client" ) @mock.patch("cosmos.providers.databricks.notebook.RunsApi") def test_databricks_notebook_operator_without_taskgroup_new_cluster( - mock_runs_api, mock_api_client, mock_monitor, dag + mock_runs_api, mock_api_client, mock_get_databricks_task_id, mock_monitor, dag ): + mock_get_databricks_task_id.return_value = "1234" mock_runs_api.return_value = mock.MagicMock() with dag: DatabricksNotebookOperator( @@ -132,6 +136,7 @@ def test_databricks_notebook_operator_without_taskgroup_new_cluster( dag.test() mock_runs_api.return_value.submit_run.assert_called_once_with( { + "run_name": "1234", "notebook_task": { "notebook_path": "/foo/bar", "base_parameters": {"source": "WORKSPACE"}, @@ -146,13 +151,17 @@ def test_databricks_notebook_operator_without_taskgroup_new_cluster( @mock.patch( "cosmos.providers.databricks.notebook.DatabricksNotebookOperator.monitor_databricks_job" ) +@mock.patch( + "cosmos.providers.databricks.notebook.DatabricksNotebookOperator._get_databricks_task_id" +) @mock.patch( "cosmos.providers.databricks.notebook.DatabricksNotebookOperator._get_api_client" ) @mock.patch("cosmos.providers.databricks.notebook.RunsApi") def test_databricks_notebook_operator_without_taskgroup_existing_cluster( - mock_runs_api, mock_api_client, 
mock_monitor, dag + mock_runs_api, mock_api_client, mock_get_databricks_task_id, mock_monitor, dag ): + mock_get_databricks_task_id.return_value = "1234" mock_runs_api.return_value = mock.MagicMock() with dag: DatabricksNotebookOperator( @@ -170,6 +179,7 @@ def test_databricks_notebook_operator_without_taskgroup_existing_cluster( dag.test() mock_runs_api.return_value.submit_run.assert_called_once_with( { + "run_name": "1234", "notebook_task": { "notebook_path": "/foo/bar", "base_parameters": {"source": "WORKSPACE"}, @@ -273,7 +283,7 @@ def test_wait_for_pending_task(mock_sleep, mock_runs_api, databricks_notebook_op {"state": {"life_cycle_state": "RUNNING"}}, ] databricks_notebook_operator._wait_for_pending_task(current_task, mock_runs_api) - mock_runs_api.get_run.assert_called_with("123") + mock_runs_api.get_run.assert_called_with("123", version="2.1") assert mock_runs_api.get_run.call_count == 2 mock_runs_api.reset_mock() @@ -290,7 +300,7 @@ def test_wait_for_terminating_task( {"state": {"life_cycle_state": "TERMINATED"}}, ] databricks_notebook_operator._wait_for_terminating_task(current_task, mock_runs_api) - mock_runs_api.get_run.assert_called_with("123") + mock_runs_api.get_run.assert_called_with("123", version="2.1") assert mock_runs_api.get_run.call_count == 3 mock_runs_api.reset_mock() @@ -305,7 +315,7 @@ def test_wait_for_running_task(mock_sleep, mock_runs_api, databricks_notebook_op {"state": {"life_cycle_state": "TERMINATED"}}, ] databricks_notebook_operator._wait_for_running_task(current_task, mock_runs_api) - mock_runs_api.get_run.assert_called_with("123") + mock_runs_api.get_run.assert_called_with("123", version="2.1") assert mock_runs_api.get_run.call_count == 3 mock_runs_api.reset_mock() @@ -328,15 +338,16 @@ def test_get_lifestyle_state(databricks_notebook_operator): "cosmos.providers.databricks.notebook.DatabricksNotebookOperator._get_api_client" ) @mock.patch( - "cosmos.providers.databricks.notebook.DatabricksNotebookOperator._get_current_databricks_task" + "cosmos.providers.databricks.notebook.DatabricksNotebookOperator._get_databricks_task_id" ) def test_monitor_databricks_job_success( - mock_get_task_name, + mock_get_databricks_task_id, mock_get_api_client, mock_runs_api, mock_databricks_hook, databricks_notebook_operator, ): + mock_get_databricks_task_id.return_value = "1" # Define the expected response response = { "run_page_url": "https://databricks-instance-xyz.cloud.databricks.com/#job/1234/run/1", @@ -354,8 +365,11 @@ def test_monitor_databricks_job_success( } mock_runs_api.return_value.get_run.return_value = response - databricks_notebook_operator.databricks_run_id = "1234" + databricks_notebook_operator.databricks_run_id = "1" databricks_notebook_operator.monitor_databricks_job() + mock_runs_api.return_value.get_run.assert_called_with( + databricks_notebook_operator.databricks_run_id, version="2.1" + ) @mock.patch("cosmos.providers.databricks.notebook.DatabricksHook") @@ -364,15 +378,16 @@ def test_monitor_databricks_job_success( "cosmos.providers.databricks.notebook.DatabricksNotebookOperator._get_api_client" ) @mock.patch( - "cosmos.providers.databricks.notebook.DatabricksNotebookOperator._get_current_databricks_task" + "cosmos.providers.databricks.notebook.DatabricksNotebookOperator._get_databricks_task_id" ) def test_monitor_databricks_job_fail( - mock_get_task_name, + mock_get_databricks_task_id, mock_get_api_client, mock_runs_api, mock_databricks_hook, databricks_notebook_operator, ): + mock_get_databricks_task_id.return_value = "1" # Define the expected 
response
     response = {
         "run_page_url": "https://databricks-instance-xyz.cloud.databricks.com/#job/1234/run/1",
         "tasks": [
             {
                 "run_id": "1",
                 "task_key": "1",
                 "state": {
                     "life_cycle_state": "TERMINATED",
                     "result_state": "FAILED",
                     "state_message": "",
                 },
             }
         ],
     }
     mock_runs_api.return_value.get_run.return_value = response
-    databricks_notebook_operator.databricks_run_id = "1234"
+    databricks_notebook_operator.databricks_run_id = "1"
     with pytest.raises(AirflowException):
         databricks_notebook_operator.monitor_databricks_job()

From 5c1375111244fa37178dcd3511498d07bf197d73 Mon Sep 17 00:00:00 2001
From: Pankaj Koti <pankajkoti699@gmail.com>
Date: Tue, 28 Feb 2023 17:47:20 +0530
Subject: [PATCH 8/8] Address @tatiana's comment

---
 cosmos/providers/databricks/notebook.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/cosmos/providers/databricks/notebook.py b/cosmos/providers/databricks/notebook.py
index 8035f7993..7ce16cad3 100644
--- a/cosmos/providers/databricks/notebook.py
+++ b/cosmos/providers/databricks/notebook.py
@@ -100,9 +100,9 @@ def convert_to_databricks_workflow_task(
         if hasattr(self.task_group, "notebook_packages"):
             self.notebook_packages.extend(self.task_group.notebook_packages)
         result = {
-            "task_key": self._get_databricks_task_id(),
+            "task_key": self._get_databricks_task_id(self.task_id),
             "depends_on": [
-                {"task_key": self.dag_id + "__" + t.replace(".", "__")}
+                {"task_key": self._get_databricks_task_id(t)}
                 for t in self.upstream_task_ids
                 if t in relevant_upstreams
             ],
@@ -118,9 +118,9 @@ def convert_to_databricks_workflow_task(
         }
         return result
 
-    def _get_databricks_task_id(self):
+    def _get_databricks_task_id(self, task_id: str):
         """Get the Databricks task ID using dag_id and task_id; removes illegal characters."""
-        return self.dag_id + "__" + self.task_id.replace(".", "__")
+        return self.dag_id + "__" + task_id.replace(".", "__")
 
     def monitor_databricks_job(self):
         """Monitor the Databricks job until it completes. Raises Airflow exception if the job fails."""
@@ -141,7 +141,7 @@ def _get_current_databricks_task(self, runs_api):
             for x in runs_api.get_run(self.databricks_run_id, version=JOBS_API_VERSION)[
                 "tasks"
             ]
-        }[self._get_databricks_task_id()]
+        }[self._get_databricks_task_id(self.task_id)]
 
     def _handle_final_state(self, final_state):
@@ -187,7 +187,7 @@ def launch_notebook_job(self):
         """Launch the notebook as a one-time job to Databricks."""
         api_client = self._get_api_client()
         run_json = {
-            "run_name": self._get_databricks_task_id(),
+            "run_name": self._get_databricks_task_id(self.task_id),
             "notebook_task": {
                 "notebook_path": self.notebook_path,
                 "base_parameters": {"source": self.source},
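
For reference, a minimal standalone sketch of the task-key convention that patches 4, 7, and 8 converge on; the dag_id and task_id values below are illustrative, not taken from the series:

    def get_databricks_task_id(dag_id: str, task_id: str) -> str:
        # Airflow namespaces tasks inside task groups with ".", which the
        # operator treats as illegal in a Databricks task key, so it is
        # replaced with "__".
        return dag_id + "__" + task_id.replace(".", "__")

    # A notebook task nested in a task group gets a dotted Airflow task_id:
    assert get_databricks_task_id("example_dag", "group.notebook_1") == "example_dag__group__notebook_1"

Because launch_notebook_job submits the one-time run with "run_name" set to this same value, the entry returned under "tasks" by the Jobs API 2.1 get_run call carries a matching "task_key", which is exactly what _get_current_databricks_task indexes on while monitoring the run.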