From a8870337c79d327d419718bd7706549df21fb1ee Mon Sep 17 00:00:00 2001
From: Pankaj Koti <pankajkoti699@gmail.com>
Date: Sat, 25 Feb 2023 20:27:41 +0530
Subject: [PATCH 1/8] Print submitted run response for debugging

---
 cosmos/providers/databricks/notebook.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/cosmos/providers/databricks/notebook.py b/cosmos/providers/databricks/notebook.py
index a299f7ede..f478314e4 100644
--- a/cosmos/providers/databricks/notebook.py
+++ b/cosmos/providers/databricks/notebook.py
@@ -193,6 +193,7 @@ def launch_notebook_job(self):
         runs_api = RunsApi(api_client)
         run = runs_api.submit_run(run_json)
         self.databricks_run_id = run["run_id"]
+        print(run)
         return run
 
     def execute(self, context: Context) -> Any:

From 77417a39e16e7f2431cd09534b1ee4e98e6daf01 Mon Sep 17 00:00:00 2001
From: Pankaj Koti <pankajkoti699@gmail.com>
Date: Sat, 25 Feb 2023 20:47:25 +0530
Subject: [PATCH 2/8] Add more debug prints around run lookup

---
 cosmos/providers/databricks/notebook.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/cosmos/providers/databricks/notebook.py b/cosmos/providers/databricks/notebook.py
index f478314e4..8ba03768b 100644
--- a/cosmos/providers/databricks/notebook.py
+++ b/cosmos/providers/databricks/notebook.py
@@ -128,6 +128,7 @@ def monitor_databricks_job(self):
         self._handle_final_state(final_state)
 
     def _get_current_databricks_task(self, runs_api):
+        print(self.databricks_run_id, runs_api.get_run(self.databricks_run_id))
         return {
             x["task_key"]: x for x in runs_api.get_run(self.databricks_run_id)["tasks"]
         }[self.dag_id + "__" + self.task_id.replace(".", "__")]
@@ -193,7 +194,7 @@ def launch_notebook_job(self):
         runs_api = RunsApi(api_client)
         run = runs_api.submit_run(run_json)
         self.databricks_run_id = run["run_id"]
-        print(run)
+        print(run, self.databricks_run_id)
         return run
 
     def execute(self, context: Context) -> Any:

From f3c69b5deef8dd78cca57108b94d32d9ef052ea8 Mon Sep 17 00:00:00 2001
From: Pankaj Koti <pankajkoti699@gmail.com>
Date: Sat, 25 Feb 2023 21:05:26 +0530
Subject: [PATCH 3/8] Pass Jobs API version to get_run calls
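
The Jobs API 2.0 response for runs/get does not appear to include the
per-task "tasks" array that _get_current_databricks_task keys into, so
the API version is now passed explicitly on each call. For reference,
a minimal sketch of the lookup this enables, assuming runs_api and
run_id are set up as in monitor_databricks_job (find_task is a
hypothetical helper name):

    from databricks_cli.runs.api import RunsApi

    def find_task(runs_api: RunsApi, run_id: str, task_key: str) -> dict:
        # Jobs API 2.1 responses carry a "tasks" array; the default
        # version used by the client does not, hence version="2.1".
        run = runs_api.get_run(run_id, version="2.1")
        return {t["task_key"]: t for t in run["tasks"]}[task_key]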

---
 cosmos/providers/databricks/notebook.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/cosmos/providers/databricks/notebook.py b/cosmos/providers/databricks/notebook.py
index 8ba03768b..3312f5a2f 100644
--- a/cosmos/providers/databricks/notebook.py
+++ b/cosmos/providers/databricks/notebook.py
@@ -128,9 +128,13 @@ def monitor_databricks_job(self):
         self._handle_final_state(final_state)
 
     def _get_current_databricks_task(self, runs_api):
-        print(self.databricks_run_id, runs_api.get_run(self.databricks_run_id))
+        print(
+            self.databricks_run_id,
+            runs_api.get_run(self.databricks_run_id, version="2.1"),
+        )
         return {
-            x["task_key"]: x for x in runs_api.get_run(self.databricks_run_id)["tasks"]
+            x["task_key"]: x
+            for x in runs_api.get_run(self.databricks_run_id, version="2.1")["tasks"]
         }[self.dag_id + "__" + self.task_id.replace(".", "__")]
 
     def _handle_final_state(self, final_state):

From fe36ff2bb1ce44adcedaa2c32b7fefff74564217 Mon Sep 17 00:00:00 2001
From: Pankaj Koti <pankajkoti699@gmail.com>
Date: Sat, 25 Feb 2023 21:27:46 +0530
Subject: [PATCH 4/8] Pass run_name so that it maps to task_key
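
For one-time runs submitted via submit_run, the task key reported back
in the Jobs API 2.1 "tasks" array is derived from the run name, so the
run name must follow the same dag_id + "__" + task_id convention that
_get_current_databricks_task uses for its lookup. A minimal sketch of
the payload, with cluster and library fields elided and placeholder
values:

    run_json = {
        # Must equal the key _get_current_databricks_task looks up.
        "run_name": "my_dag__my_task",
        "notebook_task": {
            "notebook_path": "/Shared/Notebook_1",
            "base_parameters": {"source": "WORKSPACE"},
        },
    }
    run = runs_api.submit_run(run_json)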

---
 cosmos/providers/databricks/notebook.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/cosmos/providers/databricks/notebook.py b/cosmos/providers/databricks/notebook.py
index 3312f5a2f..172bef5c5 100644
--- a/cosmos/providers/databricks/notebook.py
+++ b/cosmos/providers/databricks/notebook.py
@@ -179,6 +179,7 @@ def launch_notebook_job(self):
         """Launch the notebook as a one-time job to Databricks."""
         api_client = self._get_api_client()
         run_json = {
+            "run_name": self.dag_id + "__" + self.task_id.replace(".", "__"),
             "notebook_task": {
                 "notebook_path": self.notebook_path,
                 "base_parameters": {"source": self.source},

From 862694cf9a68ca9e1057646025fad0200d6ac6af Mon Sep 17 00:00:00 2001
From: Pankaj Koti <pankajkoti699@gmail.com>
Date: Sat, 25 Feb 2023 22:14:44 +0530
Subject: [PATCH 5/8] Add example DAG for single notebook launch
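
The hard-coded "2.1" moves into a JOBS_API_VERSION constant that can
be overridden through the environment, and a new example DAG shows a
single-notebook launch. A minimal sketch of how the override is picked
up (the constant is read once at import time, so the variable must be
set before the module is first imported; "2.0" is an illustrative
value):

    import os

    os.environ["JOBS_API_VERSION"] = "2.0"  # before the first import
    from cosmos.providers.databricks.constants import JOBS_API_VERSION

    assert JOBS_API_VERSION == "2.0"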

---
 cosmos/providers/databricks/constants.py      |  3 +
 cosmos/providers/databricks/notebook.py       | 11 ++--
 .../databricks/example_databricks_notebook.py | 61 +++++++++++++++++++
 3 files changed, 69 insertions(+), 6 deletions(-)
 create mode 100644 cosmos/providers/databricks/constants.py
 create mode 100644 examples/databricks/example_databricks_notebook.py

diff --git a/cosmos/providers/databricks/constants.py b/cosmos/providers/databricks/constants.py
new file mode 100644
index 000000000..13d81c9b9
--- /dev/null
+++ b/cosmos/providers/databricks/constants.py
@@ -0,0 +1,3 @@
+import os
+
+JOBS_API_VERSION = os.getenv("JOBS_API_VERSION", "2.1")
diff --git a/cosmos/providers/databricks/notebook.py b/cosmos/providers/databricks/notebook.py
index 172bef5c5..f5837fac8 100644
--- a/cosmos/providers/databricks/notebook.py
+++ b/cosmos/providers/databricks/notebook.py
@@ -11,6 +11,8 @@
 from databricks_cli.runs.api import RunsApi
 from databricks_cli.sdk.api_client import ApiClient
 
+from cosmos.providers.databricks.constants import JOBS_API_VERSION
+
 
 class DatabricksNotebookOperator(BaseOperator):
     """
@@ -128,13 +130,11 @@ def monitor_databricks_job(self):
         self._handle_final_state(final_state)
 
     def _get_current_databricks_task(self, runs_api):
-        print(
-            self.databricks_run_id,
-            runs_api.get_run(self.databricks_run_id, version="2.1"),
-        )
         return {
             x["task_key"]: x
-            for x in runs_api.get_run(self.databricks_run_id, version="2.1")["tasks"]
+            for x in runs_api.get_run(self.databricks_run_id, version=JOBS_API_VERSION)[
+                "tasks"
+            ]
         }[self.dag_id + "__" + self.task_id.replace(".", "__")]
 
     def _handle_final_state(self, final_state):
@@ -199,7 +199,6 @@ def launch_notebook_job(self):
         runs_api = RunsApi(api_client)
         run = runs_api.submit_run(run_json)
         self.databricks_run_id = run["run_id"]
-        print(run, self.databricks_run_id)
         return run
 
     def execute(self, context: Context) -> Any:
diff --git a/examples/databricks/example_databricks_notebook.py b/examples/databricks/example_databricks_notebook.py
new file mode 100644
index 000000000..beb96f5a0
--- /dev/null
+++ b/examples/databricks/example_databricks_notebook.py
@@ -0,0 +1,61 @@
+"""Example DAG for using the DatabricksNotebookOperator."""
+import os
+from datetime import timedelta
+
+from airflow.models.dag import DAG
+from airflow.utils.timezone import datetime
+
+from cosmos.providers.databricks.notebook import DatabricksNotebookOperator
+
+EXECUTION_TIMEOUT = int(os.getenv("EXECUTION_TIMEOUT", 6))
+default_args = {
+    "execution_timeout": timedelta(hours=EXECUTION_TIMEOUT),
+    "retries": int(os.getenv("DEFAULT_TASK_RETRIES", 2)),
+    "retry_delay": timedelta(seconds=int(os.getenv("DEFAULT_RETRY_DELAY_SECONDS", 60))),
+}
+
+DATABRICKS_CONN_ID = os.getenv("ASTRO_DATABRICKS_CONN_ID", "databricks_conn")
+NEW_CLUSTER_SPEC = {
+    "cluster_name": "",
+    "spark_version": "11.3.x-scala2.12",
+    "aws_attributes": {
+        "first_on_demand": 1,
+        "availability": "SPOT_WITH_FALLBACK",
+        "zone_id": "us-east-2b",
+        "spot_bid_price_percent": 100,
+        "ebs_volume_count": 0,
+    },
+    "node_type_id": "i3.xlarge",
+    "spark_env_vars": {"PYSPARK_PYTHON": "/databricks/python3/bin/python3"},
+    "enable_elastic_disk": False,
+    "data_security_mode": "LEGACY_SINGLE_USER_STANDARD",
+    "runtime_engine": "STANDARD",
+    "num_workers": 8,
+}
+
+dag = DAG(
+    dag_id="example_databricks_notebook",
+    start_date=datetime(2022, 1, 1),
+    schedule_interval=None,
+    catchup=False,
+    default_args=default_args,
+    tags=["example", "async", "databricks"],
+)
+with dag:
+    notebook_1 = DatabricksNotebookOperator(
+        task_id="notebook_1",
+        databricks_conn_id=DATABRICKS_CONN_ID,
+        notebook_path="/Shared/Notebook_1",
+        notebook_packages=[
+            {
+                "pypi": {
+                    "package": "simplejson==3.18.0",
+                    "repo": "https://pypi.org/simple",
+                }
+            },
+            {"pypi": {"package": "Faker"}},
+        ],
+        source="WORKSPACE",
+        job_cluster_key="random_cluster_key",
+        new_cluster=NEW_CLUSTER_SPEC,
+    )

From af3f3c11de946aea50252ce400931a5d18c6e4a4 Mon Sep 17 00:00:00 2001
From: Pankaj Koti <pankajkoti699@gmail.com>
Date: Mon, 27 Feb 2023 22:12:25 +0530
Subject: [PATCH 6/8] Fix Databricks example DAG include directive and path in docs

---
 docs/databricks/workflows.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/databricks/workflows.rst b/docs/databricks/workflows.rst
index f2140fbc1..94717ad9b 100644
--- a/docs/databricks/workflows.rst
+++ b/docs/databricks/workflows.rst
@@ -12,7 +12,7 @@ The DatabricksWorkflowTaskGroup is designed to look and function like a standard
 with the added ability to include specific Databricks arguments.
 An example of how to use the DatabricksWorkflowTaskGroup can be seen in the following code snippet:
 
-.. exampleinclude:: /../astronomer/providers/databricks/example_dags/example_databricks_workflow.py
+.. literalinclude:: /../examples/databricks/example_databricks_workflow.py
     :language: python
     :dedent: 4
     :start-after: [START howto_databricks_workflow_notebook]

From efda95afad4c8f5ca24654e9e225948bc940ef87 Mon Sep 17 00:00:00 2001
From: Pankaj Koti <pankajkoti699@gmail.com>
Date: Tue, 28 Feb 2023 15:00:08 +0530
Subject: [PATCH 7/8] Use configured Jobs API version across all get_run calls and add tests
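
Every remaining get_run call now forwards JOBS_API_VERSION, and the
dag_id + "__" + task_id naming convention is centralized in a
_get_databricks_task_id helper. For reference, a minimal sketch of the
polling pattern those calls feed, assuming a runs_api client and run
id as in the operator (wait_on_state is a hypothetical name and the
sleep interval is illustrative):

    import time

    from cosmos.providers.databricks.constants import JOBS_API_VERSION

    def wait_on_state(runs_api, run_id: str, state: str) -> None:
        # Poll until the run leaves the given lifecycle state, pinning
        # the Jobs API version on every get_run call.
        while (
            runs_api.get_run(run_id, version=JOBS_API_VERSION)["state"][
                "life_cycle_state"
            ]
            == state
        ):
            time.sleep(5)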

---
 cosmos/providers/databricks/notebook.py | 18 ++++++++----
 tests/databricks/test_notebook.py       | 37 +++++++++++++++++--------
 2 files changed, 39 insertions(+), 16 deletions(-)

diff --git a/cosmos/providers/databricks/notebook.py b/cosmos/providers/databricks/notebook.py
index f5837fac8..8035f7993 100644
--- a/cosmos/providers/databricks/notebook.py
+++ b/cosmos/providers/databricks/notebook.py
@@ -100,7 +100,7 @@ def convert_to_databricks_workflow_task(
         if hasattr(self.task_group, "notebook_packages"):
             self.notebook_packages.extend(self.task_group.notebook_packages)
         result = {
-            "task_key": self.dag_id + "__" + self.task_id.replace(".", "__"),
+            "task_key": self._get_databricks_task_id(),
             "depends_on": [
                 {"task_key": self.dag_id + "__" + t.replace(".", "__")}
                 for t in self.upstream_task_ids
@@ -118,6 +118,10 @@ def convert_to_databricks_workflow_task(
         }
         return result
 
+    def _get_databricks_task_id(self):
+        """Get the databricks task ID using dag_id and task_id. removes illegal characters."""
+        return self.dag_id + "__" + self.task_id.replace(".", "__")
+
     def monitor_databricks_job(self):
         """Monitor the Databricks job until it completes. Raises Airflow exception if the job fails."""
         api_client = self._get_api_client()
@@ -126,7 +130,9 @@ def monitor_databricks_job(self):
         self._wait_for_pending_task(current_task, runs_api)
         self._wait_for_running_task(current_task, runs_api)
         self._wait_for_terminating_task(current_task, runs_api)
-        final_state = runs_api.get_run(current_task["run_id"])["state"]
+        final_state = runs_api.get_run(
+            current_task["run_id"], version=JOBS_API_VERSION
+        )["state"]
         self._handle_final_state(final_state)
 
     def _get_current_databricks_task(self, runs_api):
@@ -135,7 +141,7 @@ def _get_current_databricks_task(self, runs_api):
             for x in runs_api.get_run(self.databricks_run_id, version=JOBS_API_VERSION)[
                 "tasks"
             ]
-        }[self.dag_id + "__" + self.task_id.replace(".", "__")]
+        }[self._get_databricks_task_id()]
 
     def _handle_final_state(self, final_state):
         if final_state.get("life_cycle_state", None) != "TERMINATED":
@@ -150,7 +156,9 @@ def _handle_final_state(self, final_state):
             )
 
     def _get_lifestyle_state(self, current_task, runs_api):
-        return runs_api.get_run(current_task["run_id"])["state"]["life_cycle_state"]
+        return runs_api.get_run(current_task["run_id"], version=JOBS_API_VERSION)[
+            "state"
+        ]["life_cycle_state"]
 
     def _wait_on_state(self, current_task, runs_api, state):
         while self._get_lifestyle_state(current_task, runs_api) == state:
@@ -179,7 +187,7 @@ def launch_notebook_job(self):
         """Launch the notebook as a one-time job to Databricks."""
         api_client = self._get_api_client()
         run_json = {
-            "run_name": self.dag_id + "__" + self.task_id.replace(".", "__"),
+            "run_name": self._get_databricks_task_id(),
             "notebook_task": {
                 "notebook_path": self.notebook_path,
                 "base_parameters": {"source": self.source},
diff --git a/tests/databricks/test_notebook.py b/tests/databricks/test_notebook.py
index f5403b06b..1c4044241 100644
--- a/tests/databricks/test_notebook.py
+++ b/tests/databricks/test_notebook.py
@@ -108,13 +108,17 @@ def test_databricks_notebook_operator_with_taskgroup(
 @mock.patch(
     "cosmos.providers.databricks.notebook.DatabricksNotebookOperator.monitor_databricks_job"
 )
+@mock.patch(
+    "cosmos.providers.databricks.notebook.DatabricksNotebookOperator._get_databricks_task_id"
+)
 @mock.patch(
     "cosmos.providers.databricks.notebook.DatabricksNotebookOperator._get_api_client"
 )
 @mock.patch("cosmos.providers.databricks.notebook.RunsApi")
 def test_databricks_notebook_operator_without_taskgroup_new_cluster(
-    mock_runs_api, mock_api_client, mock_monitor, dag
+    mock_runs_api, mock_api_client, mock_get_databricks_task_id, mock_monitor, dag
 ):
+    mock_get_databricks_task_id.return_value = "1234"
     mock_runs_api.return_value = mock.MagicMock()
     with dag:
         DatabricksNotebookOperator(
@@ -132,6 +136,7 @@ def test_databricks_notebook_operator_without_taskgroup_new_cluster(
     dag.test()
     mock_runs_api.return_value.submit_run.assert_called_once_with(
         {
+            "run_name": "1234",
             "notebook_task": {
                 "notebook_path": "/foo/bar",
                 "base_parameters": {"source": "WORKSPACE"},
@@ -146,13 +151,17 @@ def test_databricks_notebook_operator_without_taskgroup_new_cluster(
 @mock.patch(
     "cosmos.providers.databricks.notebook.DatabricksNotebookOperator.monitor_databricks_job"
 )
+@mock.patch(
+    "cosmos.providers.databricks.notebook.DatabricksNotebookOperator._get_databricks_task_id"
+)
 @mock.patch(
     "cosmos.providers.databricks.notebook.DatabricksNotebookOperator._get_api_client"
 )
 @mock.patch("cosmos.providers.databricks.notebook.RunsApi")
 def test_databricks_notebook_operator_without_taskgroup_existing_cluster(
-    mock_runs_api, mock_api_client, mock_monitor, dag
+    mock_runs_api, mock_api_client, mock_get_databricks_task_id, mock_monitor, dag
 ):
+    mock_get_databricks_task_id.return_value = "1234"
     mock_runs_api.return_value = mock.MagicMock()
     with dag:
         DatabricksNotebookOperator(
@@ -170,6 +179,7 @@ def test_databricks_notebook_operator_without_taskgroup_existing_cluster(
     dag.test()
     mock_runs_api.return_value.submit_run.assert_called_once_with(
         {
+            "run_name": "1234",
             "notebook_task": {
                 "notebook_path": "/foo/bar",
                 "base_parameters": {"source": "WORKSPACE"},
@@ -273,7 +283,7 @@ def test_wait_for_pending_task(mock_sleep, mock_runs_api, databricks_notebook_op
         {"state": {"life_cycle_state": "RUNNING"}},
     ]
     databricks_notebook_operator._wait_for_pending_task(current_task, mock_runs_api)
-    mock_runs_api.get_run.assert_called_with("123")
+    mock_runs_api.get_run.assert_called_with("123", version="2.1")
     assert mock_runs_api.get_run.call_count == 2
     mock_runs_api.reset_mock()
 
@@ -290,7 +300,7 @@ def test_wait_for_terminating_task(
         {"state": {"life_cycle_state": "TERMINATED"}},
     ]
     databricks_notebook_operator._wait_for_terminating_task(current_task, mock_runs_api)
-    mock_runs_api.get_run.assert_called_with("123")
+    mock_runs_api.get_run.assert_called_with("123", version="2.1")
     assert mock_runs_api.get_run.call_count == 3
     mock_runs_api.reset_mock()
 
@@ -305,7 +315,7 @@ def test_wait_for_running_task(mock_sleep, mock_runs_api, databricks_notebook_op
         {"state": {"life_cycle_state": "TERMINATED"}},
     ]
     databricks_notebook_operator._wait_for_running_task(current_task, mock_runs_api)
-    mock_runs_api.get_run.assert_called_with("123")
+    mock_runs_api.get_run.assert_called_with("123", version="2.1")
     assert mock_runs_api.get_run.call_count == 3
     mock_runs_api.reset_mock()
 
@@ -328,15 +338,16 @@ def test_get_lifestyle_state(databricks_notebook_operator):
     "cosmos.providers.databricks.notebook.DatabricksNotebookOperator._get_api_client"
 )
 @mock.patch(
-    "cosmos.providers.databricks.notebook.DatabricksNotebookOperator._get_current_databricks_task"
+    "cosmos.providers.databricks.notebook.DatabricksNotebookOperator._get_databricks_task_id"
 )
 def test_monitor_databricks_job_success(
-    mock_get_task_name,
+    mock_get_databricks_task_id,
     mock_get_api_client,
     mock_runs_api,
     mock_databricks_hook,
     databricks_notebook_operator,
 ):
+    mock_get_databricks_task_id.return_value = "1"
     # Define the expected response
     response = {
         "run_page_url": "https://databricks-instance-xyz.cloud.databricks.com/#job/1234/run/1",
@@ -354,8 +365,11 @@ def test_monitor_databricks_job_success(
     }
     mock_runs_api.return_value.get_run.return_value = response
 
-    databricks_notebook_operator.databricks_run_id = "1234"
+    databricks_notebook_operator.databricks_run_id = "1"
     databricks_notebook_operator.monitor_databricks_job()
+    mock_runs_api.return_value.get_run.assert_called_with(
+        databricks_notebook_operator.databricks_run_id, version="2.1"
+    )
 
 
 @mock.patch("cosmos.providers.databricks.notebook.DatabricksHook")
@@ -364,15 +378,16 @@ def test_monitor_databricks_job_success(
     "cosmos.providers.databricks.notebook.DatabricksNotebookOperator._get_api_client"
 )
 @mock.patch(
-    "cosmos.providers.databricks.notebook.DatabricksNotebookOperator._get_current_databricks_task"
+    "cosmos.providers.databricks.notebook.DatabricksNotebookOperator._get_databricks_task_id"
 )
 def test_monitor_databricks_job_fail(
-    mock_get_task_name,
+    mock_get_databricks_task_id,
     mock_get_api_client,
     mock_runs_api,
     mock_databricks_hook,
     databricks_notebook_operator,
 ):
+    mock_get_databricks_task_id.return_value = "1"
     # Define the expected response
     response = {
         "run_page_url": "https://databricks-instance-xyz.cloud.databricks.com/#job/1234/run/1",
@@ -390,6 +405,6 @@ def test_monitor_databricks_job_fail(
     }
     mock_runs_api.return_value.get_run.return_value = response
 
-    databricks_notebook_operator.databricks_run_id = "1234"
+    databricks_notebook_operator.databricks_run_id = "1"
     with pytest.raises(AirflowException):
         databricks_notebook_operator.monitor_databricks_job()

From 5c1375111244fa37178dcd3511498d07bf197d73 Mon Sep 17 00:00:00 2001
From: Pankaj Koti <pankajkoti699@gmail.com>
Date: Tue, 28 Feb 2023 17:47:20 +0530
Subject: [PATCH 8/8] Parameterize _get_databricks_task_id (address @tatiana's review comment)
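
Per the review feedback, _get_databricks_task_id now accepts the task
id as a parameter, so one helper builds the key for the current task
and for each upstream dependency instead of duplicating the string
munging in convert_to_databricks_workflow_task. A standalone sketch of
the convention (databricks_task_key mirrors the helper; the ids are
placeholders):

    def databricks_task_key(dag_id: str, task_id: str) -> str:
        # "." is not a legal Databricks task_key character, so it is
        # replaced with "__", matching _get_databricks_task_id.
        return dag_id + "__" + task_id.replace(".", "__")

    task_key = databricks_task_key("my_dag", "group.notebook_1")
    # -> "my_dag__group__notebook_1"
    depends_on = [
        {"task_key": databricks_task_key("my_dag", t)}
        for t in ["group.notebook_0"]
    ]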

---
 cosmos/providers/databricks/notebook.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/cosmos/providers/databricks/notebook.py b/cosmos/providers/databricks/notebook.py
index 8035f7993..7ce16cad3 100644
--- a/cosmos/providers/databricks/notebook.py
+++ b/cosmos/providers/databricks/notebook.py
@@ -100,9 +100,9 @@ def convert_to_databricks_workflow_task(
         if hasattr(self.task_group, "notebook_packages"):
             self.notebook_packages.extend(self.task_group.notebook_packages)
         result = {
-            "task_key": self._get_databricks_task_id(),
+            "task_key": self._get_databricks_task_id(self.task_id),
             "depends_on": [
-                {"task_key": self.dag_id + "__" + t.replace(".", "__")}
+                {"task_key": self._get_databricks_task_id(t)}
                 for t in self.upstream_task_ids
                 if t in relevant_upstreams
             ],
@@ -118,9 +118,9 @@ def convert_to_databricks_workflow_task(
         }
         return result
 
-    def _get_databricks_task_id(self):
+    def _get_databricks_task_id(self, task_id: str):
         """Get the databricks task ID using dag_id and task_id. removes illegal characters."""
-        return self.dag_id + "__" + self.task_id.replace(".", "__")
+        return self.dag_id + "__" + task_id.replace(".", "__")
 
     def monitor_databricks_job(self):
         """Monitor the Databricks job until it completes. Raises Airflow exception if the job fails."""
@@ -141,7 +141,7 @@ def _get_current_databricks_task(self, runs_api):
             for x in runs_api.get_run(self.databricks_run_id, version=JOBS_API_VERSION)[
                 "tasks"
             ]
-        }[self._get_databricks_task_id()]
+        }[self._get_databricks_task_id(self.task_id)]
 
     def _handle_final_state(self, final_state):
         if final_state.get("life_cycle_state", None) != "TERMINATED":
@@ -187,7 +187,7 @@ def launch_notebook_job(self):
         """Launch the notebook as a one-time job to Databricks."""
         api_client = self._get_api_client()
         run_json = {
-            "run_name": self._get_databricks_task_id(),
+            "run_name": self._get_databricks_task_id(self.task_id),
             "notebook_task": {
                 "notebook_path": self.notebook_path,
                 "base_parameters": {"source": self.source},