From b606b18dc40563e9d1b0f47922f6c250f05c34d4 Mon Sep 17 00:00:00 2001 From: Bora Berke Sahin <67373739+boraberke@users.noreply.github.com> Date: Thu, 18 Jul 2024 22:09:20 +0300 Subject: [PATCH] Fix named parameters templating in Databricks operators (#40864) This PR fixes templating of the named parameters, which was broken by #40471. The following operators are affected: DatabricksCreateJobsOperator DatabricksSubmitRunOperator DatabricksRunNowOperator (A brief usage sketch of the templated parameters is appended after the diff.) closes: #40788 --- .../databricks/operators/databricks.py | 173 +++++++++++++----- .../databricks/operators/test_databricks.py | 156 ++++++++++++++++ 2 files changed, 284 insertions(+), 45 deletions(-) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index a263fa9106a11..b1299b9d85040 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -263,7 +263,23 @@ class DatabricksCreateJobsOperator(BaseOperator): """ # Used in airflow.models.BaseOperator - template_fields: Sequence[str] = ("json", "databricks_conn_id") + template_fields: Sequence[str] = ( + "json", + "databricks_conn_id", + "name", + "description", + "tags", + "tasks", + "job_clusters", + "email_notifications", + "webhook_notifications", + "notification_settings", + "timeout_seconds", + "schedule", + "max_concurrent_runs", + "git_source", + "access_control_list", + ) # Databricks brand color (blue) under white text ui_color = "#1CB1C2" ui_fgcolor = "#fff" @@ -300,21 +316,19 @@ def __init__( self.databricks_retry_limit = databricks_retry_limit self.databricks_retry_delay = databricks_retry_delay self.databricks_retry_args = databricks_retry_args - self.overridden_json_params = { - "name": name, - "description": description, - "tags": tags, - "tasks": tasks, - "job_clusters": job_clusters, - "email_notifications": email_notifications, - "webhook_notifications": webhook_notifications, - "notification_settings": notification_settings, - "timeout_seconds": timeout_seconds, - "schedule": schedule, - "max_concurrent_runs": max_concurrent_runs, - "git_source": git_source, - "access_control_list": access_control_list, - } + self.name = name + self.description = description + self.tags = tags + self.tasks = tasks + self.job_clusters = job_clusters + self.email_notifications = email_notifications + self.webhook_notifications = webhook_notifications + self.notification_settings = notification_settings + self.timeout_seconds = timeout_seconds + self.schedule = schedule + self.max_concurrent_runs = max_concurrent_runs + self.git_source = git_source + self.access_control_list = access_control_list @cached_property def _hook(self): @@ -327,6 +341,22 @@ def _hook(self): ) def _setup_and_validate_json(self): + self.overridden_json_params = { + "name": self.name, + "description": self.description, + "tags": self.tags, + "tasks": self.tasks, + "job_clusters": self.job_clusters, + "email_notifications": self.email_notifications, + "webhook_notifications": self.webhook_notifications, + "notification_settings": self.notification_settings, + "timeout_seconds": self.timeout_seconds, + "schedule": self.schedule, + "max_concurrent_runs": self.max_concurrent_runs, + "git_source": self.git_source, + "access_control_list": self.access_control_list, + } + _handle_overridden_json_params(self) if "name" not in self.json: @@ -470,7 +500,25 @@ class DatabricksSubmitRunOperator(BaseOperator): """ # Used in airflow.models.BaseOperator - template_fields: Sequence[str]
= ("json", "databricks_conn_id") + template_fields: Sequence[str] = ( + "json", + "databricks_conn_id", + "tasks", + "spark_jar_task", + "notebook_task", + "spark_python_task", + "spark_submit_task", + "pipeline_task", + "dbt_task", + "new_cluster", + "existing_cluster_id", + "libraries", + "run_name", + "timeout_seconds", + "idempotency_token", + "access_control_list", + "git_source", + ) template_ext: Sequence[str] = (".json-tpl",) # Databricks brand color (blue) under white text ui_color = "#1CB1C2" @@ -516,23 +564,21 @@ def __init__( self.databricks_retry_args = databricks_retry_args self.wait_for_termination = wait_for_termination self.deferrable = deferrable - self.overridden_json_params = { - "tasks": tasks, - "spark_jar_task": spark_jar_task, - "notebook_task": notebook_task, - "spark_python_task": spark_python_task, - "spark_submit_task": spark_submit_task, - "pipeline_task": pipeline_task, - "dbt_task": dbt_task, - "new_cluster": new_cluster, - "existing_cluster_id": existing_cluster_id, - "libraries": libraries, - "run_name": run_name, - "timeout_seconds": timeout_seconds, - "idempotency_token": idempotency_token, - "access_control_list": access_control_list, - "git_source": git_source, - } + self.tasks = tasks + self.spark_jar_task = spark_jar_task + self.notebook_task = notebook_task + self.spark_python_task = spark_python_task + self.spark_submit_task = spark_submit_task + self.pipeline_task = pipeline_task + self.dbt_task = dbt_task + self.new_cluster = new_cluster + self.existing_cluster_id = existing_cluster_id + self.libraries = libraries + self.run_name = run_name + self.timeout_seconds = timeout_seconds + self.idempotency_token = idempotency_token + self.access_control_list = access_control_list + self.git_source = git_source # This variable will be used in case our task gets killed. 
self.run_id: int | None = None @@ -552,6 +598,24 @@ def _get_hook(self, caller: str) -> DatabricksHook: ) def _setup_and_validate_json(self): + self.overridden_json_params = { + "tasks": self.tasks, + "spark_jar_task": self.spark_jar_task, + "notebook_task": self.notebook_task, + "spark_python_task": self.spark_python_task, + "spark_submit_task": self.spark_submit_task, + "pipeline_task": self.pipeline_task, + "dbt_task": self.dbt_task, + "new_cluster": self.new_cluster, + "existing_cluster_id": self.existing_cluster_id, + "libraries": self.libraries, + "run_name": self.run_name, + "timeout_seconds": self.timeout_seconds, + "idempotency_token": self.idempotency_token, + "access_control_list": self.access_control_list, + "git_source": self.git_source, + } + _handle_overridden_json_params(self) if "run_name" not in self.json or self.json["run_name"] is None: @@ -772,7 +836,18 @@ class DatabricksRunNowOperator(BaseOperator): """ # Used in airflow.models.BaseOperator - template_fields: Sequence[str] = ("json", "databricks_conn_id") + template_fields: Sequence[str] = ( + "json", + "databricks_conn_id", + "job_id", + "job_name", + "notebook_params", + "python_params", + "python_named_params", + "jar_params", + "spark_submit_params", + "idempotency_token", + ) template_ext: Sequence[str] = (".json-tpl",) # Databricks brand color (blue) under white text ui_color = "#1CB1C2" @@ -815,16 +890,14 @@ def __init__( self.deferrable = deferrable self.repair_run = repair_run self.cancel_previous_runs = cancel_previous_runs - self.overridden_json_params = { - "job_id": job_id, - "job_name": job_name, - "notebook_params": notebook_params, - "python_params": python_params, - "python_named_params": python_named_params, - "jar_params": jar_params, - "spark_submit_params": spark_submit_params, - "idempotency_token": idempotency_token, - } + self.job_id = job_id + self.job_name = job_name + self.notebook_params = notebook_params + self.python_params = python_params + self.python_named_params = python_named_params + self.jar_params = jar_params + self.spark_submit_params = spark_submit_params + self.idempotency_token = idempotency_token # This variable will be used in case our task gets killed. 
self.run_id: int | None = None self.do_xcom_push = do_xcom_push @@ -843,6 +916,16 @@ def _get_hook(self, caller: str) -> DatabricksHook: ) def _setup_and_validate_json(self): + self.overridden_json_params = { + "job_id": self.job_id, + "job_name": self.job_name, + "notebook_params": self.notebook_params, + "python_params": self.python_params, + "python_named_params": self.python_named_params, + "jar_params": self.jar_params, + "spark_submit_params": self.spark_submit_params, + "idempotency_token": self.idempotency_token, + } _handle_overridden_json_params(self) if "job_id" in self.json and "job_name" in self.json: diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index ae2bb4976669c..a7337669047cb 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -66,7 +66,11 @@ RUN_ID = 1 RUN_PAGE_URL = "run-page-url" JOB_ID = "42" +TEMPLATED_JOB_ID = "job-id-{{ ds }}" +RENDERED_TEMPLATED_JOB_ID = f"job-id-{DATE}" JOB_NAME = "job-name" +TEMPLATED_JOB_NAME = "job-name-{{ ds }}" +RENDERED_TEMPLATED_JOB_NAME = f"job-name-{DATE}" JOB_DESCRIPTION = "job-description" NOTEBOOK_PARAMS = {"dry-run": "true", "oldest-time-to-consider": "1457570074236"} JAR_PARAMS = ["param1", "param2"] @@ -483,6 +487,68 @@ def test_validate_json_with_templated_json(self, db_mock_class, dag_maker): db_mock.create_job.assert_called_once_with(expected) assert JOB_ID == tis[TASK_ID].xcom_pull() + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_validate_json_with_templated_named_param(self, db_mock_class, dag_maker): + json = "{{ ti.xcom_pull(task_ids='push_json') }}" + with dag_maker("test_templated", render_template_as_native_obj=True): + push_json = PythonOperator( + task_id="push_json", + python_callable=lambda: { + "description": JOB_DESCRIPTION, + "tags": TAGS, + "tasks": TASKS, + "job_clusters": JOB_CLUSTERS, + "email_notifications": EMAIL_NOTIFICATIONS, + "webhook_notifications": WEBHOOK_NOTIFICATIONS, + "notification_settings": NOTIFICATION_SETTINGS, + "timeout_seconds": TIMEOUT_SECONDS, + "schedule": SCHEDULE, + "max_concurrent_runs": MAX_CONCURRENT_RUNS, + "git_source": GIT_SOURCE, + "access_control_list": ACCESS_CONTROL_LIST, + }, + ) + op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json, name=TEMPLATED_JOB_NAME) + push_json >> op + + db_mock = db_mock_class.return_value + db_mock.create_job.return_value = JOB_ID + + db_mock.find_job_id_by_name.return_value = None + + dagrun = dag_maker.create_dagrun(execution_date=datetime.strptime(DATE, "%Y-%m-%d")) + tis = {ti.task_id: ti for ti in dagrun.task_instances} + tis["push_json"].run() + tis[TASK_ID].run() + + expected = utils._normalise_json_content( + { + "name": RENDERED_TEMPLATED_JOB_NAME, + "description": JOB_DESCRIPTION, + "tags": TAGS, + "tasks": TASKS, + "job_clusters": JOB_CLUSTERS, + "email_notifications": EMAIL_NOTIFICATIONS, + "webhook_notifications": WEBHOOK_NOTIFICATIONS, + "notification_settings": NOTIFICATION_SETTINGS, + "timeout_seconds": TIMEOUT_SECONDS, + "schedule": SCHEDULE, + "max_concurrent_runs": MAX_CONCURRENT_RUNS, + "git_source": GIT_SOURCE, + "access_control_list": ACCESS_CONTROL_LIST, + } + ) + db_mock_class.assert_called_once_with( + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, + caller="DatabricksCreateJobsOperator", + ) + + 
db_mock.create_job.assert_called_once_with(expected) + assert JOB_ID == tis[TASK_ID].xcom_pull() + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_validate_json_with_xcomarg_json(self, db_mock_class, dag_maker): with dag_maker("test_xcomarg", render_template_as_native_obj=True): @@ -1015,6 +1081,50 @@ def test_validate_json_with_templated_json(self, db_mock_class, dag_maker): db_mock.get_run_page_url.assert_called_once_with(RUN_ID) db_mock.get_run.assert_called_once_with(RUN_ID) + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_validate_json_with_templated_named_params(self, db_mock_class, dag_maker): + json = "{{ ti.xcom_pull(task_ids='push_json') }}" + with dag_maker("test_templated", render_template_as_native_obj=True): + push_json = PythonOperator( + task_id="push_json", + python_callable=lambda: { + "new_cluster": NEW_CLUSTER, + }, + ) + op = DatabricksSubmitRunOperator( + task_id=TASK_ID, json=json, notebook_task=TEMPLATED_NOTEBOOK_TASK + ) + push_json >> op + + db_mock = db_mock_class.return_value + db_mock.submit_run.return_value = RUN_ID + db_mock.get_run_page_url.return_value = RUN_PAGE_URL + db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS") + + dagrun = dag_maker.create_dagrun(execution_date=datetime.strptime(DATE, "%Y-%m-%d")) + tis = {ti.task_id: ti for ti in dagrun.task_instances} + tis["push_json"].run() + tis[TASK_ID].run() + + expected = utils._normalise_json_content( + { + "new_cluster": NEW_CLUSTER, + "notebook_task": RENDERED_TEMPLATED_NOTEBOOK_TASK, + "run_name": TASK_ID, + } + ) + db_mock_class.assert_called_once_with( + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, + caller="DatabricksSubmitRunOperator", + ) + + db_mock.submit_run.assert_called_once_with(expected) + db_mock.get_run_page_url.assert_called_once_with(RUN_ID) + db_mock.get_run.assert_called_once_with(RUN_ID) + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_validate_json_with_xcomarg_json(self, db_mock_class, dag_maker): with dag_maker("test_xcomarg", render_template_as_native_obj=True): @@ -1534,6 +1644,52 @@ def test_validate_json_with_templated_json(self, db_mock_class, dag_maker): db_mock.get_run_page_url.assert_called_once_with(RUN_ID) db_mock.get_run.assert_called_once_with(RUN_ID) + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_validate_json_with_templated_named_params(self, db_mock_class, dag_maker): + json = "{{ ti.xcom_pull(task_ids='push_json') }}" + with dag_maker("test_templated", render_template_as_native_obj=True): + push_json = PythonOperator( + task_id="push_json", + python_callable=lambda: { + "notebook_params": NOTEBOOK_PARAMS, + "notebook_task": NOTEBOOK_TASK, + }, + ) + op = DatabricksRunNowOperator( + task_id=TASK_ID, job_id=TEMPLATED_JOB_ID, jar_params=TEMPLATED_JAR_PARAMS, json=json + ) + push_json >> op + + db_mock = db_mock_class.return_value + db_mock.run_now.return_value = RUN_ID + db_mock.get_run_page_url.return_value = RUN_PAGE_URL + db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS") + + dagrun = dag_maker.create_dagrun(execution_date=datetime.strptime(DATE, "%Y-%m-%d")) + tis = {ti.task_id: ti for ti in dagrun.task_instances} + tis["push_json"].run() + tis[TASK_ID].run() + + expected = utils._normalise_json_content( + { + "notebook_params": NOTEBOOK_PARAMS, + "notebook_task": NOTEBOOK_TASK, + 
"jar_params": RENDERED_TEMPLATED_JAR_PARAMS, + "job_id": RENDERED_TEMPLATED_JOB_ID, + } + ) + + db_mock_class.assert_called_once_with( + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, + caller="DatabricksRunNowOperator", + ) + db_mock.run_now.assert_called_once_with(expected) + db_mock.get_run_page_url.assert_called_once_with(RUN_ID) + db_mock.get_run.assert_called_once_with(RUN_ID) + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_validate_json_with_xcomarg_json(self, db_mock_class, dag_maker): with dag_maker("test_xcomarg", render_template_as_native_obj=True):