From 7472909f2a6b993d56b6baac72054a4d5f9e0c59 Mon Sep 17 00:00:00 2001 From: boraberke Date: Thu, 27 Jun 2024 12:31:05 +0300 Subject: [PATCH 1/9] move assignments of `json` fields to the `execute` function --- .../databricks/operators/databricks.py | 43 ++--- .../databricks/operators/test_databricks.py | 147 +++++++++++++++--- 2 files changed, 146 insertions(+), 44 deletions(-) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 38ce8e332455a..5fe32885eb32f 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -806,27 +806,16 @@ def __init__( self.deferrable = deferrable self.repair_run = repair_run self.cancel_previous_runs = cancel_previous_runs - - if job_id is not None: - self.json["job_id"] = job_id - if job_name is not None: - self.json["job_name"] = job_name - if "job_id" in self.json and "job_name" in self.json: - raise AirflowException("Argument 'job_name' is not allowed with argument 'job_id'") - if notebook_params is not None: - self.json["notebook_params"] = notebook_params - if python_params is not None: - self.json["python_params"] = python_params - if python_named_params is not None: - self.json["python_named_params"] = python_named_params - if jar_params is not None: - self.json["jar_params"] = jar_params - if spark_submit_params is not None: - self.json["spark_submit_params"] = spark_submit_params - if idempotency_token is not None: - self.json["idempotency_token"] = idempotency_token - if self.json: - self.json = normalise_json_content(self.json) + self._overridden_json_params = { + "job_id": job_id, + "job_name": job_name, + "notebook_params": notebook_params, + "python_params": python_params, + "python_named_params": python_named_params, + "jar_params": jar_params, + "spark_submit_params": spark_submit_params, + "idempotency_token": idempotency_token, + } # This variable will be used in case our task gets killed. 
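(Aside, for context on the pattern this patch introduces: instead of copying each named argument into `self.json` inside `__init__`, the operator now only records the arguments and defers the merge and validation to execute time, so a `json` value supplied via templating or XCom is already rendered by the time it is inspected. Below is a minimal standalone sketch of that pattern, not Airflow code; the class name, the reduced argument list, and the use of `ValueError` instead of `AirflowException` are illustrative simplifications.)

class RunNowLikeOperator:
    def __init__(self, json=None, job_id=None, job_name=None, notebook_params=None):
        self.json = json or {}
        # Mirrors `_overridden_json_params` from the patch: arguments are only recorded here.
        self._overridden_json_params = {
            "job_id": job_id,
            "job_name": job_name,
            "notebook_params": notebook_params,
        }

    def _setup_and_validate_json(self):
        # Named arguments override top-level keys of the (already rendered) json.
        for key, value in self._overridden_json_params.items():
            if value is not None:
                self.json[key] = value
        if "job_id" in self.json and "job_name" in self.json:
            raise ValueError("Argument 'job_name' is not allowed with argument 'job_id'")

    def execute(self):
        # Validation now happens at run time, after templating/XCom resolution.
        self._setup_and_validate_json()
        return self.json


op = RunNowLikeOperator(json={"notebook_params": {"dry-run": "true"}}, job_id=42)
assert op.execute() == {"notebook_params": {"dry-run": "true"}, "job_id": 42}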
self.run_id: int | None = None self.do_xcom_push = do_xcom_push @@ -844,7 +833,19 @@ def _get_hook(self, caller: str) -> DatabricksHook: caller=caller, ) + def _setup_and_validate_json(self): + for key, value in self._overridden_json_params.items(): + if value is not None: + self.json[key] = value + + if "job_id" in self.json and "job_name" in self.json: + raise AirflowException("Argument 'job_name' is not allowed with argument 'job_id'") + + if self.json: + self.json = normalise_json_content(self.json) + def execute(self, context: Context): + self._setup_and_validate_json() hook = self._hook if "job_name" in self.json: job_id = hook.find_job_id_by_name(self.json["job_name"]) diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index 7ff2295eda94a..ebb9d1dd6c1fd 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -23,8 +23,10 @@ import pytest +from airflow.decorators import task from airflow.exceptions import AirflowException, TaskDeferred from airflow.models import DAG +from airflow.operators.python import PythonOperator from airflow.providers.databricks.hooks.databricks import RunState from airflow.providers.databricks.operators.databricks import ( DatabricksCreateJobsOperator, @@ -36,6 +38,7 @@ ) from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger from airflow.providers.databricks.utils import databricks as utils +from airflow.utils import timezone pytestmark = pytest.mark.db_test @@ -1125,18 +1128,20 @@ def test_databricks_submit_run_deferrable_operator_success_before_defer(self, mo class TestDatabricksRunNowOperator: - def test_init_with_named_parameters(self): + def test_validate_json_with_named_parameters(self): """ - Test the initializer with the named parameters. + Test the _setup_and_validate_json function with named parameters. """ op = DatabricksRunNowOperator(job_id=JOB_ID, task_id=TASK_ID) + op._setup_and_validate_json() + expected = utils.normalise_json_content({"job_id": 42}) assert expected == op.json - def test_init_with_json(self): + def test_validate_json_with_json(self): """ - Test the initializer with json data. + Test the _setup_and_validate_json function with json data. """ json = { "notebook_params": NOTEBOOK_PARAMS, @@ -1147,6 +1152,7 @@ def test_init_with_json(self): "repair_run": False, } op = DatabricksRunNowOperator(task_id=TASK_ID, json=json) + op._setup_and_validate_json() expected = utils.normalise_json_content( { @@ -1161,9 +1167,9 @@ def test_init_with_json(self): assert expected == op.json - def test_init_with_merging(self): + def test_validate_json_with_merging(self): """ - Test the initializer when json and other named parameters are both + Test the _setup_and_validate_json function when json and other named parameters are both provided. The named parameters should override top level keys in the json dict. 
""" @@ -1180,6 +1186,7 @@ def test_init_with_merging(self): jar_params=override_jar_params, spark_submit_params=SPARK_SUBMIT_PARAMS, ) + op._setup_and_validate_json() expected = utils.normalise_json_content( { @@ -1194,12 +1201,13 @@ def test_init_with_merging(self): assert expected == op.json @pytest.mark.db_test - def test_init_with_templating(self): + def test_validate_json_with_templating(self): json = {"notebook_params": NOTEBOOK_PARAMS, "jar_params": TEMPLATED_JAR_PARAMS} dag = DAG("test", start_date=datetime.now()) op = DatabricksRunNowOperator(dag=dag, task_id=TASK_ID, job_id=JOB_ID, json=json) op.render_template_fields(context={"ds": DATE}) + op._setup_and_validate_json() expected = utils.normalise_json_content( { "notebook_params": NOTEBOOK_PARAMS, @@ -1209,7 +1217,7 @@ def test_init_with_templating(self): ) assert expected == op.json - def test_init_with_bad_type(self): + def test_validate_json_with_bad_type(self): json = {"test": datetime.now()} # Looks a bit weird since we have to escape regex reserved symbols. exception_message = ( @@ -1217,7 +1225,114 @@ def test_init_with_bad_type(self): r"for parameter json\[test\] is not a number or a string" ) with pytest.raises(AirflowException, match=exception_message): - DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=json) + DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=json)._setup_and_validate_json() + + def test_validate_json_exception_with_job_name_and_job_id(self): + exception_message = "Argument 'job_name' is not allowed with argument 'job_id'" + + with pytest.raises(AirflowException, match=exception_message): + DatabricksRunNowOperator( + task_id=TASK_ID, job_id=JOB_ID, job_name=JOB_NAME + )._setup_and_validate_json() + + run = {"job_id": JOB_ID, "job_name": JOB_NAME} + with pytest.raises(AirflowException, match=exception_message): + DatabricksRunNowOperator(task_id=TASK_ID, json=run)._setup_and_validate_json() + + run = {"job_id": JOB_ID} + with pytest.raises(AirflowException, match=exception_message): + DatabricksRunNowOperator(task_id=TASK_ID, json=run, job_name=JOB_NAME)._setup_and_validate_json() + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_validate_json_with_templated_json(self, db_mock_class, dag_maker): + json = "{{ ti.xcom_pull(task_ids='push_json') }}" + with dag_maker("test_chime_notifier", render_template_as_native_obj=True) as dag: + push_json = PythonOperator( + task_id="push_json", + python_callable=lambda: { + "notebook_params": NOTEBOOK_PARAMS, + "notebook_task": NOTEBOOK_TASK, + "jar_params": JAR_PARAMS, + "job_id": JOB_ID, + }, + ) + op = DatabricksRunNowOperator(dag=dag, task_id=TASK_ID, job_id=JOB_ID, json=json) + push_json >> op + + db_mock = db_mock_class.return_value + db_mock.run_now.return_value = RUN_ID + db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS") + + dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow()) + tis = {ti.task_id: ti for ti in dagrun.task_instances} + tis["push_json"].run() + tis[TASK_ID].run() + + expected = utils.normalise_json_content( + { + "notebook_params": NOTEBOOK_PARAMS, + "notebook_task": NOTEBOOK_TASK, + "jar_params": JAR_PARAMS, + "job_id": JOB_ID, + } + ) + + db_mock_class.assert_called_once_with( + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, + caller="DatabricksRunNowOperator", + ) + db_mock.run_now.assert_called_once_with(expected) + 
db_mock.get_run_page_url.assert_called_once_with(RUN_ID) + db_mock.get_run.assert_called_once_with(RUN_ID) + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_validate_json_with_xcomarg_json(self, db_mock_class, dag_maker): + with dag_maker("test_xcomarg", render_template_as_native_obj=True) as dag: + + @task + def push_json() -> dict[str, str]: + return { + "notebook_params": NOTEBOOK_PARAMS, + "notebook_task": NOTEBOOK_TASK, + "jar_params": JAR_PARAMS, + "job_id": JOB_ID, + } + + json = push_json() + + op = DatabricksRunNowOperator(dag=dag, task_id=TASK_ID, job_id=JOB_ID, json=json) + + db_mock = db_mock_class.return_value + db_mock.run_now.return_value = RUN_ID + db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS") + + dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow()) + tis = {ti.task_id: ti for ti in dagrun.task_instances} + tis["push_json"].run() + tis[TASK_ID].run() + + expected = utils.normalise_json_content( + { + "notebook_params": NOTEBOOK_PARAMS, + "notebook_task": NOTEBOOK_TASK, + "jar_params": JAR_PARAMS, + "job_id": JOB_ID, + } + ) + + db_mock_class.assert_called_once_with( + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, + caller="DatabricksRunNowOperator", + ) + db_mock.run_now.assert_called_once_with(expected) + db_mock.get_run_page_url.assert_called_once_with(RUN_ID) + db_mock.get_run.assert_called_once_with(RUN_ID) @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_exec_success(self, db_mock_class): @@ -1486,20 +1601,6 @@ def test_no_wait_for_termination(self, db_mock_class): db_mock.get_run_page_url.assert_called_once_with(RUN_ID) db_mock.get_run.assert_not_called() - def test_init_exception_with_job_name_and_job_id(self): - exception_message = "Argument 'job_name' is not allowed with argument 'job_id'" - - with pytest.raises(AirflowException, match=exception_message): - DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, job_name=JOB_NAME) - - run = {"job_id": JOB_ID, "job_name": JOB_NAME} - with pytest.raises(AirflowException, match=exception_message): - DatabricksRunNowOperator(task_id=TASK_ID, json=run) - - run = {"job_id": JOB_ID} - with pytest.raises(AirflowException, match=exception_message): - DatabricksRunNowOperator(task_id=TASK_ID, json=run, job_name=JOB_NAME) - @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_exec_with_job_name(self, db_mock_class): run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS} From 0ee66816ea8ffc35e7b340f3219c62b0aecc8907 Mon Sep 17 00:00:00 2001 From: boraberke Date: Thu, 27 Jun 2024 13:12:03 +0300 Subject: [PATCH 2/9] fix dag test name --- tests/providers/databricks/operators/test_databricks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index ebb9d1dd6c1fd..e22889a55c841 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -1246,7 +1246,7 @@ def test_validate_json_exception_with_job_name_and_job_id(self): @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_validate_json_with_templated_json(self, db_mock_class, dag_maker): json = "{{ ti.xcom_pull(task_ids='push_json') }}" - with dag_maker("test_chime_notifier", 
render_template_as_native_obj=True) as dag: + with dag_maker("test_templated", render_template_as_native_obj=True) as dag: push_json = PythonOperator( task_id="push_json", python_callable=lambda: { From 6f80f7f822601b81fa78f581f9c1c715fd1807e0 Mon Sep 17 00:00:00 2001 From: boraberke Date: Thu, 27 Jun 2024 13:17:00 +0300 Subject: [PATCH 3/9] move assignments of `json` fields to the `execute` function --- .../databricks/operators/databricks.py | 75 ++++---- .../databricks/operators/test_databricks.py | 173 ++++++++++++++---- 2 files changed, 177 insertions(+), 71 deletions(-) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 5fe32885eb32f..cefaea9eb345f 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -505,43 +505,23 @@ def __init__( self.databricks_retry_args = databricks_retry_args self.wait_for_termination = wait_for_termination self.deferrable = deferrable - if tasks is not None: - self.json["tasks"] = tasks - if spark_jar_task is not None: - self.json["spark_jar_task"] = spark_jar_task - if notebook_task is not None: - self.json["notebook_task"] = notebook_task - if spark_python_task is not None: - self.json["spark_python_task"] = spark_python_task - if spark_submit_task is not None: - self.json["spark_submit_task"] = spark_submit_task - if pipeline_task is not None: - self.json["pipeline_task"] = pipeline_task - if dbt_task is not None: - self.json["dbt_task"] = dbt_task - if new_cluster is not None: - self.json["new_cluster"] = new_cluster - if existing_cluster_id is not None: - self.json["existing_cluster_id"] = existing_cluster_id - if libraries is not None: - self.json["libraries"] = libraries - if run_name is not None: - self.json["run_name"] = run_name - if timeout_seconds is not None: - self.json["timeout_seconds"] = timeout_seconds - if "run_name" not in self.json: - self.json["run_name"] = run_name or kwargs["task_id"] - if idempotency_token is not None: - self.json["idempotency_token"] = idempotency_token - if access_control_list is not None: - self.json["access_control_list"] = access_control_list - if git_source is not None: - self.json["git_source"] = git_source - - if "dbt_task" in self.json and "git_source" not in self.json: - raise AirflowException("git_source is required for dbt_task") - if pipeline_task is not None and "pipeline_id" in pipeline_task and "pipeline_name" in pipeline_task: - raise AirflowException("'pipeline_name' is not allowed in conjunction with 'pipeline_id'") + self._overridden_json_params = { + "tasks": tasks, + "spark_jar_task": spark_jar_task, + "notebook_task": notebook_task, + "spark_python_task": spark_python_task, + "spark_submit_task": spark_submit_task, + "pipeline_task": pipeline_task, + "dbt_task": dbt_task, + "new_cluster": new_cluster, + "existing_cluster_id": existing_cluster_id, + "libraries": libraries, + "run_name": run_name, + "timeout_seconds": timeout_seconds, + "idempotency_token": idempotency_token, + "access_control_list": access_control_list, + "git_source": git_source, + } # This variable will be used in case our task gets killed. 
self.run_id: int | None = None @@ -560,7 +540,28 @@ def _get_hook(self, caller: str) -> DatabricksHook: caller=caller, ) + def _setup_and_validate_json(self): + for key, value in self._overridden_json_params.items(): + if value is not None: + self.json[key] = value + + if "run_name" not in self.json or self.json["run_name"] is None: + self.json["run_name"] = self.task_id + + if "dbt_task" in self.json and "git_source" not in self.json: + raise AirflowException("git_source is required for dbt_task") + if ( + "pipeline_task" in self.json + and "pipeline_id" in self.json["pipeline_task"] + and "pipeline_name" in self.json["pipeline_task"] + ): + raise AirflowException("'pipeline_name' is not allowed in conjunction with 'pipeline_id'") + + if self.json: + self.json = normalise_json_content(self.json) + def execute(self, context: Context): + self._setup_and_validate_json() if ( "pipeline_task" in self.json and self.json["pipeline_task"].get("pipeline_id") is None diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index e22889a55c841..c93a65e59e749 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -589,66 +589,76 @@ def test_exec_update_job_permission_with_empty_acl(self, db_mock_class): class TestDatabricksSubmitRunOperator: - def test_init_with_notebook_task_named_parameters(self): + def test_validate_json_with_notebook_task_named_parameters(self): """ - Test the initializer with the named parameters. + Test the _setup_and_validate_json function with named parameters. """ op = DatabricksSubmitRunOperator( task_id=TASK_ID, new_cluster=NEW_CLUSTER, notebook_task=NOTEBOOK_TASK ) + op._setup_and_validate_json() + expected = utils.normalise_json_content( {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID} ) assert expected == utils.normalise_json_content(op.json) - def test_init_with_spark_python_task_named_parameters(self): + def test_validate_json_with_spark_python_task_named_parameters(self): """ - Test the initializer with the named parameters. + Test the _setup_and_validate_json function with the named parameters. """ op = DatabricksSubmitRunOperator( task_id=TASK_ID, new_cluster=NEW_CLUSTER, spark_python_task=SPARK_PYTHON_TASK ) + op._setup_and_validate_json() + expected = utils.normalise_json_content( {"new_cluster": NEW_CLUSTER, "spark_python_task": SPARK_PYTHON_TASK, "run_name": TASK_ID} ) assert expected == utils.normalise_json_content(op.json) - def test_init_with_pipeline_name_task_named_parameters(self): + def test_validate_json_with_pipeline_name_task_named_parameters(self): """ - Test the initializer with the named parameters. + Test the _setup_and_validate_json function with the named parameters. """ op = DatabricksSubmitRunOperator(task_id=TASK_ID, pipeline_task=PIPELINE_NAME_TASK) + op._setup_and_validate_json() + expected = utils.normalise_json_content({"pipeline_task": PIPELINE_NAME_TASK, "run_name": TASK_ID}) assert expected == utils.normalise_json_content(op.json) - def test_init_with_pipeline_id_task_named_parameters(self): + def test_validate_json_with_pipeline_id_task_named_parameters(self): """ - Test the initializer with the named parameters. + Test the _setup_and_validate_json function with the named parameters. 
""" op = DatabricksSubmitRunOperator(task_id=TASK_ID, pipeline_task=PIPELINE_ID_TASK) + op._setup_and_validate_json() + expected = utils.normalise_json_content({"pipeline_task": PIPELINE_ID_TASK, "run_name": TASK_ID}) assert expected == utils.normalise_json_content(op.json) - def test_init_with_spark_submit_task_named_parameters(self): + def test_validate_json_with_spark_submit_task_named_parameters(self): """ - Test the initializer with the named parameters. + Test the _setup_and_validate_json function with the named parameters. """ op = DatabricksSubmitRunOperator( task_id=TASK_ID, new_cluster=NEW_CLUSTER, spark_submit_task=SPARK_SUBMIT_TASK ) + op._setup_and_validate_json() + expected = utils.normalise_json_content( {"new_cluster": NEW_CLUSTER, "spark_submit_task": SPARK_SUBMIT_TASK, "run_name": TASK_ID} ) assert expected == utils.normalise_json_content(op.json) - def test_init_with_dbt_task_named_parameters(self): + def test_validate_json_with_dbt_task_named_parameters(self): """ - Test the initializer with the named parameters. + Test the _setup_and_validate_json function with the named parameters. """ git_source = { "git_url": "https://github.com/dbt-labs/jaffle_shop", @@ -658,15 +668,17 @@ def test_init_with_dbt_task_named_parameters(self): op = DatabricksSubmitRunOperator( task_id=TASK_ID, new_cluster=NEW_CLUSTER, dbt_task=DBT_TASK, git_source=git_source ) + op._setup_and_validate_json() + expected = utils.normalise_json_content( {"new_cluster": NEW_CLUSTER, "dbt_task": DBT_TASK, "git_source": git_source, "run_name": TASK_ID} ) assert expected == utils.normalise_json_content(op.json) - def test_init_with_dbt_task_mixed_parameters(self): + def test_validate_json_with_dbt_task_mixed_parameters(self): """ - Test the initializer with mixed parameters. + Test the _setup_and_validate_json function with mixed parameters. """ git_source = { "git_url": "https://github.com/dbt-labs/jaffle_shop", @@ -677,73 +689,85 @@ def test_init_with_dbt_task_mixed_parameters(self): op = DatabricksSubmitRunOperator( task_id=TASK_ID, new_cluster=NEW_CLUSTER, dbt_task=DBT_TASK, json=json ) + op._setup_and_validate_json() + expected = utils.normalise_json_content( {"new_cluster": NEW_CLUSTER, "dbt_task": DBT_TASK, "git_source": git_source, "run_name": TASK_ID} ) assert expected == utils.normalise_json_content(op.json) - def test_init_with_dbt_task_without_git_source_raises_error(self): + def test_validate_json_with_dbt_task_without_git_source_raises_error(self): """ - Test the initializer without the necessary git_source for dbt_task raises error. + Test the _setup_and_validate_json function without the necessary git_source for dbt_task raises error. """ exception_message = "git_source is required for dbt_task" with pytest.raises(AirflowException, match=exception_message): - DatabricksSubmitRunOperator(task_id=TASK_ID, new_cluster=NEW_CLUSTER, dbt_task=DBT_TASK) + DatabricksSubmitRunOperator( + task_id=TASK_ID, new_cluster=NEW_CLUSTER, dbt_task=DBT_TASK + )._setup_and_validate_json() - def test_init_with_dbt_task_json_without_git_source_raises_error(self): + def test_validate_json_with_dbt_task_json_without_git_source_raises_error(self): """ - Test the initializer without the necessary git_source for dbt_task raises error. + Test the _setup_and_validate_json function without the necessary git_source for dbt_task raises error. 
""" json = {"dbt_task": DBT_TASK, "new_cluster": NEW_CLUSTER} exception_message = "git_source is required for dbt_task" with pytest.raises(AirflowException, match=exception_message): - DatabricksSubmitRunOperator(task_id=TASK_ID, json=json) + DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)._setup_and_validate_json() - def test_init_with_json(self): + def test_validate_json_with_json(self): """ - Test the initializer with json data. + Test the _setup_and_validate_json function with json data. """ json = {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK} op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json) + op._setup_and_validate_json() + expected = utils.normalise_json_content( {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID} ) assert expected == utils.normalise_json_content(op.json) - def test_init_with_tasks(self): + def test_validate_json_with_tasks(self): tasks = [{"task_key": 1, "new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK}] op = DatabricksSubmitRunOperator(task_id=TASK_ID, tasks=tasks) + op._setup_and_validate_json() + expected = utils.normalise_json_content({"run_name": TASK_ID, "tasks": tasks}) assert expected == utils.normalise_json_content(op.json) - def test_init_with_specified_run_name(self): + def test_validate_json_with_specified_run_name(self): """ - Test the initializer with a specified run_name. + Test the _setup_and_validate_json function with a specified run_name. """ json = {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": RUN_NAME} op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json) + op._setup_and_validate_json() + expected = utils.normalise_json_content( {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": RUN_NAME} ) assert expected == utils.normalise_json_content(op.json) - def test_pipeline_task(self): + def test_validate_json_with_pipeline_task(self): """ - Test the initializer with a pipeline task. + Test the _setup_and_validate_json function with a pipeline task. """ pipeline_task = {"pipeline_id": "test-dlt"} json = {"new_cluster": NEW_CLUSTER, "run_name": RUN_NAME, "pipeline_task": pipeline_task} op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json) + op._setup_and_validate_json() + expected = utils.normalise_json_content( {"new_cluster": NEW_CLUSTER, "pipeline_task": pipeline_task, "run_name": RUN_NAME} ) assert expected == utils.normalise_json_content(op.json) - def test_init_with_merging(self): + def test_validate_json_with_merging(self): """ - Test the initializer when json and other named parameters are both + Test the _setup_and_validate_json function when json and other named parameters are both provided. The named parameters should override top level keys in the json dict. 
""" @@ -753,6 +777,8 @@ def test_init_with_merging(self): "notebook_task": NOTEBOOK_TASK, } op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json, new_cluster=override_new_cluster) + op._setup_and_validate_json() + expected = utils.normalise_json_content( { "new_cluster": override_new_cluster, @@ -763,13 +789,15 @@ def test_init_with_merging(self): assert expected == utils.normalise_json_content(op.json) @pytest.mark.db_test - def test_init_with_templating(self): + def test_validate_json_with_templating(self): json = { "new_cluster": NEW_CLUSTER, "notebook_task": TEMPLATED_NOTEBOOK_TASK, } dag = DAG("test", start_date=datetime.now()) op = DatabricksSubmitRunOperator(dag=dag, task_id=TASK_ID, json=json) + op._setup_and_validate_json() + op.render_template_fields(context={"ds": DATE}) expected = utils.normalise_json_content( { @@ -780,7 +808,7 @@ def test_init_with_templating(self): ) assert expected == utils.normalise_json_content(op.json) - def test_init_with_git_source(self): + def test_validate_json_with_git_source(self): json = {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": RUN_NAME} git_source = { "git_url": "https://github.com/apache/airflow", @@ -788,6 +816,8 @@ def test_init_with_git_source(self): "git_branch": "main", } op = DatabricksSubmitRunOperator(task_id=TASK_ID, git_source=git_source, json=json) + op._setup_and_validate_json() + expected = utils.normalise_json_content( { "new_cluster": NEW_CLUSTER, @@ -798,16 +828,91 @@ def test_init_with_git_source(self): ) assert expected == utils.normalise_json_content(op.json) - def test_init_with_bad_type(self): + def test_validate_json_with_bad_type(self): json = {"test": datetime.now()} - op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json) # Looks a bit weird since we have to escape regex reserved symbols. 
exception_message = ( r"Type \<(type|class) \'datetime.datetime\'\> used " r"for parameter json\[test\] is not a number or a string" ) with pytest.raises(AirflowException, match=exception_message): - utils.normalise_json_content(op.json) + DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)._setup_and_validate_json() + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_validate_json_with_templated_json(self, db_mock_class, dag_maker): + json = "{{ ti.xcom_pull(task_ids='push_json') }}" + with dag_maker("test_templated", render_template_as_native_obj=True) as dag: + push_json = PythonOperator( + task_id="push_json", + python_callable=lambda: { + "new_cluster": NEW_CLUSTER, + "notebook_task": NOTEBOOK_TASK, + }, + ) + op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json) + push_json >> op + + db_mock = db_mock_class.return_value + db_mock.submit_run.return_value = RUN_ID + db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS") + + dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow()) + tis = {ti.task_id: ti for ti in dagrun.task_instances} + tis["push_json"].run() + tis[TASK_ID].run() + + expected = utils.normalise_json_content( + {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID} + ) + db_mock_class.assert_called_once_with( + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, + caller="DatabricksSubmitRunOperator", + ) + + db_mock.submit_run.assert_called_once_with(expected) + db_mock.get_run_page_url.assert_called_once_with(RUN_ID) + db_mock.get_run.assert_called_once_with(RUN_ID) + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_validate_json_with_xcomarg_json(self, db_mock_class, dag_maker): + with dag_maker("test_xcomarg", render_template_as_native_obj=True) as dag: + + @task + def push_json() -> dict: + return { + "new_cluster": NEW_CLUSTER, + "notebook_task": NOTEBOOK_TASK, + } + + json = push_json() + + op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json) + db_mock = db_mock_class.return_value + db_mock.submit_run.return_value = RUN_ID + db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS") + + dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow()) + tis = {ti.task_id: ti for ti in dagrun.task_instances} + tis["push_json"].run() + tis[TASK_ID].run() + + expected = utils.normalise_json_content( + {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID} + ) + db_mock_class.assert_called_once_with( + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, + caller="DatabricksSubmitRunOperator", + ) + + db_mock.submit_run.assert_called_once_with(expected) + db_mock.get_run_page_url.assert_called_once_with(RUN_ID) + db_mock.get_run.assert_called_once_with(RUN_ID) @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_exec_success(self, db_mock_class): From c8c3912d9182762444ce5c8fb368f36d815dcf17 Mon Sep 17 00:00:00 2001 From: boraberke Date: Thu, 27 Jun 2024 22:07:40 +0300 Subject: [PATCH 4/9] DatabricksCreateJobsOperator: move assignments of `json` fields --- .../databricks/operators/databricks.py | 56 +++---- .../databricks/operators/test_databricks.py | 158 ++++++++++++++++-- 2 files changed, 175 insertions(+), 39 deletions(-) diff --git a/airflow/providers/databricks/operators/databricks.py 
b/airflow/providers/databricks/operators/databricks.py index cefaea9eb345f..38e51ef105878 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -284,34 +284,21 @@ def __init__( self.databricks_retry_limit = databricks_retry_limit self.databricks_retry_delay = databricks_retry_delay self.databricks_retry_args = databricks_retry_args - if name is not None: - self.json["name"] = name - if description is not None: - self.json["description"] = description - if tags is not None: - self.json["tags"] = tags - if tasks is not None: - self.json["tasks"] = tasks - if job_clusters is not None: - self.json["job_clusters"] = job_clusters - if email_notifications is not None: - self.json["email_notifications"] = email_notifications - if webhook_notifications is not None: - self.json["webhook_notifications"] = webhook_notifications - if notification_settings is not None: - self.json["notification_settings"] = notification_settings - if timeout_seconds is not None: - self.json["timeout_seconds"] = timeout_seconds - if schedule is not None: - self.json["schedule"] = schedule - if max_concurrent_runs is not None: - self.json["max_concurrent_runs"] = max_concurrent_runs - if git_source is not None: - self.json["git_source"] = git_source - if access_control_list is not None: - self.json["access_control_list"] = access_control_list - if self.json: - self.json = normalise_json_content(self.json) + self._overridden_json_params = { + "name": name, + "description": description, + "tags": tags, + "tasks": tasks, + "job_clusters": job_clusters, + "email_notifications": email_notifications, + "webhook_notifications": webhook_notifications, + "notification_settings": notification_settings, + "timeout_seconds": timeout_seconds, + "schedule": schedule, + "max_concurrent_runs": max_concurrent_runs, + "git_source": git_source, + "access_control_list": access_control_list, + } @cached_property def _hook(self): @@ -323,9 +310,20 @@ def _hook(self): caller="DatabricksCreateJobsOperator", ) - def execute(self, context: Context) -> int: + def _setup_and_validate_json(self): + for key, value in self._overridden_json_params.items(): + if value is not None: + self.json[key] = value + if "name" not in self.json: raise AirflowException("Missing required parameter: name") + + if self.json: + self.json = normalise_json_content(self.json) + + def execute(self, context: Context) -> int: + self._setup_and_validate_json() + job_id = self._hook.find_job_id_by_name(self.json["name"]) if job_id is None: return self._hook.create_job(self.json) diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index c93a65e59e749..6f213eedee5e1 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -251,9 +251,9 @@ def make_run_with_state_mock( class TestDatabricksCreateJobsOperator: - def test_init_with_named_parameters(self): + def test_validate_json_with_named_parameters(self): """ - Test the initializer with the named parameters. + Test the _setup_and_validate_json function with the named parameters. 
""" op = DatabricksCreateJobsOperator( task_id=TASK_ID, @@ -269,6 +269,8 @@ def test_init_with_named_parameters(self): git_source=GIT_SOURCE, access_control_list=ACCESS_CONTROL_LIST, ) + op._setup_and_validate_json() + expected = utils.normalise_json_content( { "name": JOB_NAME, @@ -287,9 +289,9 @@ def test_init_with_named_parameters(self): assert expected == op.json - def test_init_with_json(self): + def test_validate_json_with_json(self): """ - Test the initializer with json data. + Test the _setup_and_validate_json function with json data. """ json = { "name": JOB_NAME, @@ -305,6 +307,7 @@ def test_init_with_json(self): "access_control_list": ACCESS_CONTROL_LIST, } op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json) + op._setup_and_validate_json() expected = utils.normalise_json_content( { @@ -324,9 +327,9 @@ def test_init_with_json(self): assert expected == op.json - def test_init_with_merging(self): + def test_validate_json_with_merging(self): """ - Test the initializer when json and other named parameters are both + Test the _setup_and_validate_json function when json and other named parameters are both provided. The named parameters should override top level keys in the json dict. """ @@ -370,6 +373,7 @@ def test_init_with_merging(self): git_source=override_git_source, access_control_list=override_access_control_list, ) + op._setup_and_validate_json() expected = utils.normalise_json_content( { @@ -389,24 +393,158 @@ def test_init_with_merging(self): assert expected == op.json - def test_init_with_templating(self): + def test_validate_json_with_templating(self): json = {"name": "test-{{ ds }}"} dag = DAG("test", start_date=datetime.now()) op = DatabricksCreateJobsOperator(dag=dag, task_id=TASK_ID, json=json) op.render_template_fields(context={"ds": DATE}) + op._setup_and_validate_json() + expected = utils.normalise_json_content({"name": f"test-{DATE}"}) assert expected == op.json - def test_init_with_bad_type(self): - json = {"test": datetime.now()} + def test_validate_json_with_bad_type(self): + json = {"test": datetime.now(), "name": "test"} # Looks a bit weird since we have to escape regex reserved symbols. 
exception_message = ( r"Type \<(type|class) \'datetime.datetime\'\> used " r"for parameter json\[test\] is not a number or a string" ) with pytest.raises(AirflowException, match=exception_message): - DatabricksCreateJobsOperator(task_id=TASK_ID, json=json) + DatabricksCreateJobsOperator(task_id=TASK_ID, json=json)._setup_and_validate_json() + + def test_validate_json_with_no_name(self): + json = {} + exception_message = "Missing required parameter: name" + with pytest.raises(AirflowException, match=exception_message): + DatabricksCreateJobsOperator(task_id=TASK_ID, json=json)._setup_and_validate_json() + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_validate_json_with_templated_json(self, db_mock_class, dag_maker): + json = "{{ ti.xcom_pull(task_ids='push_json') }}" + with dag_maker("test_templated", render_template_as_native_obj=True): + push_json = PythonOperator( + task_id="push_json", + python_callable=lambda: { + "name": JOB_NAME, + "description": JOB_DESCRIPTION, + "tags": TAGS, + "tasks": TASKS, + "job_clusters": JOB_CLUSTERS, + "email_notifications": EMAIL_NOTIFICATIONS, + "webhook_notifications": WEBHOOK_NOTIFICATIONS, + "notification_settings": NOTIFICATION_SETTINGS, + "timeout_seconds": TIMEOUT_SECONDS, + "schedule": SCHEDULE, + "max_concurrent_runs": MAX_CONCURRENT_RUNS, + "git_source": GIT_SOURCE, + "access_control_list": ACCESS_CONTROL_LIST, + }, + ) + op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json) + push_json >> op + + db_mock = db_mock_class.return_value + db_mock.create_job.return_value = JOB_ID + + db_mock.find_job_id_by_name.return_value = None + + dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow()) + tis = {ti.task_id: ti for ti in dagrun.task_instances} + tis["push_json"].run() + tis[TASK_ID].run() + + expected = utils.normalise_json_content( + { + "name": JOB_NAME, + "description": JOB_DESCRIPTION, + "tags": TAGS, + "tasks": TASKS, + "job_clusters": JOB_CLUSTERS, + "email_notifications": EMAIL_NOTIFICATIONS, + "webhook_notifications": WEBHOOK_NOTIFICATIONS, + "notification_settings": NOTIFICATION_SETTINGS, + "timeout_seconds": TIMEOUT_SECONDS, + "schedule": SCHEDULE, + "max_concurrent_runs": MAX_CONCURRENT_RUNS, + "git_source": GIT_SOURCE, + "access_control_list": ACCESS_CONTROL_LIST, + } + ) + db_mock_class.assert_called_once_with( + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, + caller="DatabricksCreateJobsOperator", + ) + + db_mock.create_job.assert_called_once_with(expected) + assert JOB_ID == tis[TASK_ID].xcom_pull() + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_validate_json_with_xcomarg_json(self, db_mock_class, dag_maker): + with dag_maker("test_xcomarg", render_template_as_native_obj=True): + + @task + def push_json() -> dict: + return { + "name": JOB_NAME, + "description": JOB_DESCRIPTION, + "tags": TAGS, + "tasks": TASKS, + "job_clusters": JOB_CLUSTERS, + "email_notifications": EMAIL_NOTIFICATIONS, + "webhook_notifications": WEBHOOK_NOTIFICATIONS, + "notification_settings": NOTIFICATION_SETTINGS, + "timeout_seconds": TIMEOUT_SECONDS, + "schedule": SCHEDULE, + "max_concurrent_runs": MAX_CONCURRENT_RUNS, + "git_source": GIT_SOURCE, + "access_control_list": ACCESS_CONTROL_LIST, + } + + json = push_json() + op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json) + + db_mock = db_mock_class.return_value + db_mock.create_job.return_value = JOB_ID + + 
db_mock.find_job_id_by_name.return_value = None + + dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow()) + tis = {ti.task_id: ti for ti in dagrun.task_instances} + tis["push_json"].run() + tis[TASK_ID].run() + + expected = utils.normalise_json_content( + { + "name": JOB_NAME, + "description": JOB_DESCRIPTION, + "tags": TAGS, + "tasks": TASKS, + "job_clusters": JOB_CLUSTERS, + "email_notifications": EMAIL_NOTIFICATIONS, + "webhook_notifications": WEBHOOK_NOTIFICATIONS, + "notification_settings": NOTIFICATION_SETTINGS, + "timeout_seconds": TIMEOUT_SECONDS, + "schedule": SCHEDULE, + "max_concurrent_runs": MAX_CONCURRENT_RUNS, + "git_source": GIT_SOURCE, + "access_control_list": ACCESS_CONTROL_LIST, + } + ) + db_mock_class.assert_called_once_with( + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, + caller="DatabricksCreateJobsOperator", + ) + + db_mock.create_job.assert_called_once_with(expected) + assert JOB_ID == tis[TASK_ID].xcom_pull() @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_exec_create(self, db_mock_class): From bd2bc46f8284f96c40812643cc0e5b7b9547f1a6 Mon Sep 17 00:00:00 2001 From: boraberke Date: Thu, 27 Jun 2024 22:08:37 +0300 Subject: [PATCH 5/9] remove unnecessary variables --- tests/providers/databricks/operators/test_databricks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index 6f213eedee5e1..7c3c538a41932 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -979,7 +979,7 @@ def test_validate_json_with_bad_type(self): @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_validate_json_with_templated_json(self, db_mock_class, dag_maker): json = "{{ ti.xcom_pull(task_ids='push_json') }}" - with dag_maker("test_templated", render_template_as_native_obj=True) as dag: + with dag_maker("test_templated", render_template_as_native_obj=True): push_json = PythonOperator( task_id="push_json", python_callable=lambda: { @@ -1016,7 +1016,7 @@ def test_validate_json_with_templated_json(self, db_mock_class, dag_maker): @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_validate_json_with_xcomarg_json(self, db_mock_class, dag_maker): - with dag_maker("test_xcomarg", render_template_as_native_obj=True) as dag: + with dag_maker("test_xcomarg", render_template_as_native_obj=True): @task def push_json() -> dict: @@ -1489,7 +1489,7 @@ def test_validate_json_exception_with_job_name_and_job_id(self): @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_validate_json_with_templated_json(self, db_mock_class, dag_maker): json = "{{ ti.xcom_pull(task_ids='push_json') }}" - with dag_maker("test_templated", render_template_as_native_obj=True) as dag: + with dag_maker("test_templated", render_template_as_native_obj=True): push_json = PythonOperator( task_id="push_json", python_callable=lambda: { @@ -1533,7 +1533,7 @@ def test_validate_json_with_templated_json(self, db_mock_class, dag_maker): @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_validate_json_with_xcomarg_json(self, db_mock_class, dag_maker): - with dag_maker("test_xcomarg", render_template_as_native_obj=True) as dag: + with dag_maker("test_xcomarg", 
render_template_as_native_obj=True): @task def push_json() -> dict[str, str]: From 5f2acdc830c8a3d96c0449afa32e741c8ade7bc1 Mon Sep 17 00:00:00 2001 From: boraberke Date: Thu, 27 Jun 2024 22:29:15 +0300 Subject: [PATCH 6/9] remove unused dag variables --- tests/providers/databricks/operators/test_databricks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index 7c3c538a41932..1e8cb5b963d4b 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -1499,7 +1499,7 @@ def test_validate_json_with_templated_json(self, db_mock_class, dag_maker): "job_id": JOB_ID, }, ) - op = DatabricksRunNowOperator(dag=dag, task_id=TASK_ID, job_id=JOB_ID, json=json) + op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=json) push_json >> op db_mock = db_mock_class.return_value @@ -1546,7 +1546,7 @@ def push_json() -> dict[str, str]: json = push_json() - op = DatabricksRunNowOperator(dag=dag, task_id=TASK_ID, job_id=JOB_ID, json=json) + op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=json) db_mock = db_mock_class.return_value db_mock.run_now.return_value = RUN_ID From 2fa9c6e2711cedee0892d58701d8769c05796b7f Mon Sep 17 00:00:00 2001 From: boraberke Date: Fri, 28 Jun 2024 01:57:26 +0300 Subject: [PATCH 7/9] fix static checks --- tests/providers/databricks/operators/test_databricks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index 1e8cb5b963d4b..6eed62768399f 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -1536,7 +1536,7 @@ def test_validate_json_with_xcomarg_json(self, db_mock_class, dag_maker): with dag_maker("test_xcomarg", render_template_as_native_obj=True): @task - def push_json() -> dict[str, str]: + def push_json() -> dict: return { "notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, From b1bd751f8060d74ab05583fba037b8ae4e4384a1 Mon Sep 17 00:00:00 2001 From: boraberke Date: Fri, 28 Jun 2024 02:58:57 +0300 Subject: [PATCH 8/9] fix failing compatibility tests --- tests/providers/databricks/operators/test_databricks.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index 6eed62768399f..b91e8074990cf 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -992,6 +992,7 @@ def test_validate_json_with_templated_json(self, db_mock_class, dag_maker): db_mock = db_mock_class.return_value db_mock.submit_run.return_value = RUN_ID + db_mock.get_run_page_url.return_value = RUN_PAGE_URL db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS") dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow()) @@ -1030,6 +1031,7 @@ def push_json() -> dict: op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json) db_mock = db_mock_class.return_value db_mock.submit_run.return_value = RUN_ID + db_mock.get_run_page_url.return_value = RUN_PAGE_URL db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS") dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow()) @@ -1504,6 +1506,7 @@ def 
test_validate_json_with_templated_json(self, db_mock_class, dag_maker): db_mock = db_mock_class.return_value db_mock.run_now.return_value = RUN_ID + db_mock.get_run_page_url.return_value = RUN_PAGE_URL db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS") dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow()) @@ -1550,6 +1553,7 @@ def push_json() -> dict: db_mock = db_mock_class.return_value db_mock.run_now.return_value = RUN_ID + db_mock.get_run_page_url.return_value = RUN_PAGE_URL db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS") dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow()) From 1cfaf4c58d8522ab7c24dcc636873f6003a1cf9c Mon Sep 17 00:00:00 2001 From: boraberke Date: Fri, 28 Jun 2024 14:14:46 +0300 Subject: [PATCH 9/9] gather repeating json logic into single functions --- .../databricks/operators/databricks.py | 46 ++++--- .../providers/databricks/utils/databricks.py | 4 +- .../databricks/operators/test_databricks.py | 130 +++++++++--------- .../databricks/utils/test_databricks.py | 4 +- 4 files changed, 93 insertions(+), 91 deletions(-) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 38e51ef105878..666e08d3f0123 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -36,7 +36,7 @@ WorkflowRunMetadata, ) from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger -from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event +from airflow.providers.databricks.utils.databricks import _normalise_json_content, validate_trigger_event if TYPE_CHECKING: from airflow.models.taskinstancekey import TaskInstanceKey @@ -182,6 +182,17 @@ def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger) raise AirflowException(error_message) +def _handle_overridden_json_params(operator): + for key, value in operator.overridden_json_params.items(): + if value is not None: + operator.json[key] = value + + +def normalise_json_content(operator): + if operator.json: + operator.json = _normalise_json_content(operator.json) + + class DatabricksJobRunLink(BaseOperatorLink): """Constructs a link to monitor a Databricks Job Run.""" @@ -284,7 +295,7 @@ def __init__( self.databricks_retry_limit = databricks_retry_limit self.databricks_retry_delay = databricks_retry_delay self.databricks_retry_args = databricks_retry_args - self._overridden_json_params = { + self.overridden_json_params = { "name": name, "description": description, "tags": tags, @@ -311,15 +322,12 @@ def _hook(self): ) def _setup_and_validate_json(self): - for key, value in self._overridden_json_params.items(): - if value is not None: - self.json[key] = value + _handle_overridden_json_params(self) if "name" not in self.json: raise AirflowException("Missing required parameter: name") - if self.json: - self.json = normalise_json_content(self.json) + normalise_json_content(self) def execute(self, context: Context) -> int: self._setup_and_validate_json() @@ -330,7 +338,7 @@ def execute(self, context: Context) -> int: self._hook.reset_job(str(job_id), self.json) if (access_control_list := self.json.get("access_control_list")) is not None: acl_json = {"access_control_list": access_control_list} - self._hook.update_job_permission(job_id, normalise_json_content(acl_json)) + self._hook.update_job_permission(job_id, _normalise_json_content(acl_json)) return 
job_id @@ -503,7 +511,7 @@ def __init__( self.databricks_retry_args = databricks_retry_args self.wait_for_termination = wait_for_termination self.deferrable = deferrable - self._overridden_json_params = { + self.overridden_json_params = { "tasks": tasks, "spark_jar_task": spark_jar_task, "notebook_task": notebook_task, @@ -539,9 +547,7 @@ def _get_hook(self, caller: str) -> DatabricksHook: ) def _setup_and_validate_json(self): - for key, value in self._overridden_json_params.items(): - if value is not None: - self.json[key] = value + _handle_overridden_json_params(self) if "run_name" not in self.json or self.json["run_name"] is None: self.json["run_name"] = self.task_id @@ -555,8 +561,7 @@ def _setup_and_validate_json(self): ): raise AirflowException("'pipeline_name' is not allowed in conjunction with 'pipeline_id'") - if self.json: - self.json = normalise_json_content(self.json) + normalise_json_content(self) def execute(self, context: Context): self._setup_and_validate_json() @@ -569,7 +574,7 @@ def execute(self, context: Context): pipeline_name = self.json["pipeline_task"]["pipeline_name"] self.json["pipeline_task"]["pipeline_id"] = self._hook.find_pipeline_id_by_name(pipeline_name) del self.json["pipeline_task"]["pipeline_name"] - json_normalised = normalise_json_content(self.json) + json_normalised = _normalise_json_content(self.json) self.run_id = self._hook.submit_run(json_normalised) if self.deferrable: _handle_deferrable_databricks_operator_execution(self, self._hook, self.log, context) @@ -605,7 +610,7 @@ def __init__(self, *args, **kwargs): def execute(self, context): hook = self._get_hook(caller="DatabricksSubmitRunDeferrableOperator") - json_normalised = normalise_json_content(self.json) + json_normalised = _normalise_json_content(self.json) self.run_id = hook.submit_run(json_normalised) _handle_deferrable_databricks_operator_execution(self, hook, self.log, context) @@ -805,7 +810,7 @@ def __init__( self.deferrable = deferrable self.repair_run = repair_run self.cancel_previous_runs = cancel_previous_runs - self._overridden_json_params = { + self.overridden_json_params = { "job_id": job_id, "job_name": job_name, "notebook_params": notebook_params, @@ -833,15 +838,12 @@ def _get_hook(self, caller: str) -> DatabricksHook: ) def _setup_and_validate_json(self): - for key, value in self._overridden_json_params.items(): - if value is not None: - self.json[key] = value + _handle_overridden_json_params(self) if "job_id" in self.json and "job_name" in self.json: raise AirflowException("Argument 'job_name' is not allowed with argument 'job_id'") - if self.json: - self.json = normalise_json_content(self.json) + normalise_json_content(self) def execute(self, context: Context): self._setup_and_validate_json() diff --git a/airflow/providers/databricks/utils/databricks.py b/airflow/providers/databricks/utils/databricks.py index 88d622c3bc1fb..ec99bce17873c 100644 --- a/airflow/providers/databricks/utils/databricks.py +++ b/airflow/providers/databricks/utils/databricks.py @@ -21,7 +21,7 @@ from airflow.providers.databricks.hooks.databricks import RunState -def normalise_json_content(content, json_path: str = "json") -> str | bool | list | dict: +def _normalise_json_content(content, json_path: str = "json") -> str | bool | list | dict: """ Normalize content or all values of content if it is a dict to a string. 
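(Aside: patch 9 consolidates the per-operator merge and normalise steps into the module-level helpers `_handle_overridden_json_params` and `normalise_json_content`, and renames the utils function to `_normalise_json_content`. The standalone sketch below shows how those pieces compose once the patch is applied; it is not Airflow code — the simplified normaliser, the `CreateJobsLikeOperator` stand-in, and the use of `ValueError`/`TypeError` instead of `AirflowException` are illustrative assumptions.)

def _normalise_json_content(content):
    # Simplified stand-in for the renamed utils function: stringify numbers, recurse
    # into containers, and keep strings and booleans as-is.
    if isinstance(content, (str, bool)):
        return content
    if isinstance(content, (int, float)):
        return str(content)
    if isinstance(content, (list, tuple)):
        return [_normalise_json_content(e) for e in content]
    if isinstance(content, dict):
        return {k: _normalise_json_content(v) for k, v in content.items()}
    raise TypeError(f"unsupported type: {type(content)}")


def _handle_overridden_json_params(operator):
    # Merge the named constructor arguments over the (possibly templated) json.
    for key, value in operator.overridden_json_params.items():
        if value is not None:
            operator.json[key] = value


def normalise_json_content(operator):
    if operator.json:
        operator.json = _normalise_json_content(operator.json)


class CreateJobsLikeOperator:
    def __init__(self, json=None, name=None):
        self.json = json or {}
        self.overridden_json_params = {"name": name}

    def _setup_and_validate_json(self):
        _handle_overridden_json_params(self)
        if "name" not in self.json:
            raise ValueError("Missing required parameter: name")
        normalise_json_content(self)


op = CreateJobsLikeOperator(json={"max_concurrent_runs": 1}, name="nightly")
op._setup_and_validate_json()
assert op.json == {"max_concurrent_runs": "1", "name": "nightly"}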
@@ -33,7 +33,7 @@ def normalise_json_content(content, json_path: str = "json") -> str | bool | lis The only one exception is when we have boolean values, they can not be converted to string type because databricks does not understand 'True' or 'False' values. """ - normalise = normalise_json_content + normalise = _normalise_json_content if isinstance(content, (str, bool)): return content elif isinstance(content, (int, float)): diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index b91e8074990cf..ae2bb4976669c 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -271,7 +271,7 @@ def test_validate_json_with_named_parameters(self): ) op._setup_and_validate_json() - expected = utils.normalise_json_content( + expected = utils._normalise_json_content( { "name": JOB_NAME, "tags": TAGS, @@ -309,7 +309,7 @@ def test_validate_json_with_json(self): op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json) op._setup_and_validate_json() - expected = utils.normalise_json_content( + expected = utils._normalise_json_content( { "name": JOB_NAME, "tags": TAGS, @@ -375,7 +375,7 @@ def test_validate_json_with_merging(self): ) op._setup_and_validate_json() - expected = utils.normalise_json_content( + expected = utils._normalise_json_content( { "name": override_name, "tags": override_tags, @@ -401,7 +401,7 @@ def test_validate_json_with_templating(self): op.render_template_fields(context={"ds": DATE}) op._setup_and_validate_json() - expected = utils.normalise_json_content({"name": f"test-{DATE}"}) + expected = utils._normalise_json_content({"name": f"test-{DATE}"}) assert expected == op.json def test_validate_json_with_bad_type(self): @@ -455,7 +455,7 @@ def test_validate_json_with_templated_json(self, db_mock_class, dag_maker): tis["push_json"].run() tis[TASK_ID].run() - expected = utils.normalise_json_content( + expected = utils._normalise_json_content( { "name": JOB_NAME, "description": JOB_DESCRIPTION, @@ -518,7 +518,7 @@ def push_json() -> dict: tis["push_json"].run() tis[TASK_ID].run() - expected = utils.normalise_json_content( + expected = utils._normalise_json_content( { "name": JOB_NAME, "description": JOB_DESCRIPTION, @@ -574,7 +574,7 @@ def test_exec_create(self, db_mock_class): return_result = op.execute({}) - expected = utils.normalise_json_content( + expected = utils._normalise_json_content( { "name": JOB_NAME, "description": JOB_DESCRIPTION, @@ -628,7 +628,7 @@ def test_exec_reset(self, db_mock_class): return_result = op.execute({}) - expected = utils.normalise_json_content( + expected = utils._normalise_json_content( { "name": JOB_NAME, "description": JOB_DESCRIPTION, @@ -680,7 +680,7 @@ def test_exec_update_job_permission(self, db_mock_class): op.execute({}) - expected = utils.normalise_json_content({"access_control_list": ACCESS_CONTROL_LIST}) + expected = utils._normalise_json_content({"access_control_list": ACCESS_CONTROL_LIST}) db_mock_class.assert_called_once_with( DEFAULT_CONN_ID, @@ -736,11 +736,11 @@ def test_validate_json_with_notebook_task_named_parameters(self): ) op._setup_and_validate_json() - expected = utils.normalise_json_content( + expected = utils._normalise_json_content( {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID} ) - assert expected == utils.normalise_json_content(op.json) + assert expected == utils._normalise_json_content(op.json) def 
test_validate_json_with_spark_python_task_named_parameters(self): """ @@ -751,11 +751,11 @@ def test_validate_json_with_spark_python_task_named_parameters(self): ) op._setup_and_validate_json() - expected = utils.normalise_json_content( + expected = utils._normalise_json_content( {"new_cluster": NEW_CLUSTER, "spark_python_task": SPARK_PYTHON_TASK, "run_name": TASK_ID} ) - assert expected == utils.normalise_json_content(op.json) + assert expected == utils._normalise_json_content(op.json) def test_validate_json_with_pipeline_name_task_named_parameters(self): """ @@ -764,9 +764,9 @@ def test_validate_json_with_pipeline_name_task_named_parameters(self): op = DatabricksSubmitRunOperator(task_id=TASK_ID, pipeline_task=PIPELINE_NAME_TASK) op._setup_and_validate_json() - expected = utils.normalise_json_content({"pipeline_task": PIPELINE_NAME_TASK, "run_name": TASK_ID}) + expected = utils._normalise_json_content({"pipeline_task": PIPELINE_NAME_TASK, "run_name": TASK_ID}) - assert expected == utils.normalise_json_content(op.json) + assert expected == utils._normalise_json_content(op.json) def test_validate_json_with_pipeline_id_task_named_parameters(self): """ @@ -775,9 +775,9 @@ def test_validate_json_with_pipeline_id_task_named_parameters(self): op = DatabricksSubmitRunOperator(task_id=TASK_ID, pipeline_task=PIPELINE_ID_TASK) op._setup_and_validate_json() - expected = utils.normalise_json_content({"pipeline_task": PIPELINE_ID_TASK, "run_name": TASK_ID}) + expected = utils._normalise_json_content({"pipeline_task": PIPELINE_ID_TASK, "run_name": TASK_ID}) - assert expected == utils.normalise_json_content(op.json) + assert expected == utils._normalise_json_content(op.json) def test_validate_json_with_spark_submit_task_named_parameters(self): """ @@ -788,11 +788,11 @@ def test_validate_json_with_spark_submit_task_named_parameters(self): ) op._setup_and_validate_json() - expected = utils.normalise_json_content( + expected = utils._normalise_json_content( {"new_cluster": NEW_CLUSTER, "spark_submit_task": SPARK_SUBMIT_TASK, "run_name": TASK_ID} ) - assert expected == utils.normalise_json_content(op.json) + assert expected == utils._normalise_json_content(op.json) def test_validate_json_with_dbt_task_named_parameters(self): """ @@ -808,11 +808,11 @@ def test_validate_json_with_dbt_task_named_parameters(self): ) op._setup_and_validate_json() - expected = utils.normalise_json_content( + expected = utils._normalise_json_content( {"new_cluster": NEW_CLUSTER, "dbt_task": DBT_TASK, "git_source": git_source, "run_name": TASK_ID} ) - assert expected == utils.normalise_json_content(op.json) + assert expected == utils._normalise_json_content(op.json) def test_validate_json_with_dbt_task_mixed_parameters(self): """ @@ -829,11 +829,11 @@ def test_validate_json_with_dbt_task_mixed_parameters(self): ) op._setup_and_validate_json() - expected = utils.normalise_json_content( + expected = utils._normalise_json_content( {"new_cluster": NEW_CLUSTER, "dbt_task": DBT_TASK, "git_source": git_source, "run_name": TASK_ID} ) - assert expected == utils.normalise_json_content(op.json) + assert expected == utils._normalise_json_content(op.json) def test_validate_json_with_dbt_task_without_git_source_raises_error(self): """ @@ -863,18 +863,18 @@ def test_validate_json_with_json(self): op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json) op._setup_and_validate_json() - expected = utils.normalise_json_content( + expected = utils._normalise_json_content( {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": 
             {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID}
         )
 
-        assert expected == utils.normalise_json_content(op.json)
+        assert expected == utils._normalise_json_content(op.json)
 
     def test_validate_json_with_tasks(self):
         tasks = [{"task_key": 1, "new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK}]
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, tasks=tasks)
         op._setup_and_validate_json()
-        expected = utils.normalise_json_content({"run_name": TASK_ID, "tasks": tasks})
-        assert expected == utils.normalise_json_content(op.json)
+        expected = utils._normalise_json_content({"run_name": TASK_ID, "tasks": tasks})
+        assert expected == utils._normalise_json_content(op.json)
 
     def test_validate_json_with_specified_run_name(self):
         """
@@ -884,10 +884,10 @@ def test_validate_json_with_specified_run_name(self):
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
         op._setup_and_validate_json()
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": RUN_NAME}
         )
-        assert expected == utils.normalise_json_content(op.json)
+        assert expected == utils._normalise_json_content(op.json)
 
     def test_validate_json_with_pipeline_task(self):
         """
@@ -898,10 +898,10 @@ def test_validate_json_with_pipeline_task(self):
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
         op._setup_and_validate_json()
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {"new_cluster": NEW_CLUSTER, "pipeline_task": pipeline_task, "run_name": RUN_NAME}
         )
-        assert expected == utils.normalise_json_content(op.json)
+        assert expected == utils._normalise_json_content(op.json)
 
     def test_validate_json_with_merging(self):
         """
@@ -917,14 +917,14 @@ def test_validate_json_with_merging(self):
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json, new_cluster=override_new_cluster)
         op._setup_and_validate_json()
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "new_cluster": override_new_cluster,
                 "notebook_task": NOTEBOOK_TASK,
                 "run_name": TASK_ID,
             }
         )
-        assert expected == utils.normalise_json_content(op.json)
+        assert expected == utils._normalise_json_content(op.json)
 
     @pytest.mark.db_test
     def test_validate_json_with_templating(self):
@@ -937,14 +937,14 @@ def test_validate_json_with_templating(self):
         op._setup_and_validate_json()
         op.render_template_fields(context={"ds": DATE})
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "new_cluster": NEW_CLUSTER,
                 "notebook_task": RENDERED_TEMPLATED_NOTEBOOK_TASK,
                 "run_name": TASK_ID,
             }
         )
-        assert expected == utils.normalise_json_content(op.json)
+        assert expected == utils._normalise_json_content(op.json)
 
     def test_validate_json_with_git_source(self):
         json = {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": RUN_NAME}
@@ -956,7 +956,7 @@ def test_validate_json_with_git_source(self):
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, git_source=git_source, json=json)
         op._setup_and_validate_json()
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "new_cluster": NEW_CLUSTER,
                 "notebook_task": NOTEBOOK_TASK,
@@ -964,7 +964,7 @@ def test_validate_json_with_git_source(self):
                 "git_source": git_source,
             }
         )
-        assert expected == utils.normalise_json_content(op.json)
+        assert expected == utils._normalise_json_content(op.json)
 
     def test_validate_json_with_bad_type(self):
         json = {"test": datetime.now()}
@@ -1000,7 +1000,7 @@ def test_validate_json_with_templated_json(self, db_mock_class, dag_maker):
         tis["push_json"].run()
         tis[TASK_ID].run()
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID}
         )
         db_mock_class.assert_called_once_with(
@@ -1039,7 +1039,7 @@ def push_json() -> dict:
         tis["push_json"].run()
         tis[TASK_ID].run()
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID}
         )
         db_mock_class.assert_called_once_with(
@@ -1070,7 +1070,7 @@ def test_exec_success(self, db_mock_class):
 
         op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID}
         )
         db_mock_class.assert_called_once_with(
@@ -1100,7 +1100,7 @@ def test_exec_pipeline_name(self, db_mock_class):
 
         op.execute(None)
 
-        expected = utils.normalise_json_content({"pipeline_task": PIPELINE_ID_TASK, "run_name": TASK_ID})
+        expected = utils._normalise_json_content({"pipeline_task": PIPELINE_ID_TASK, "run_name": TASK_ID})
         db_mock_class.assert_called_once_with(
             DEFAULT_CONN_ID,
             retry_limit=op.databricks_retry_limit,
@@ -1132,7 +1132,7 @@ def test_exec_failure(self, db_mock_class):
         with pytest.raises(AirflowException):
             op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "new_cluster": NEW_CLUSTER,
                 "notebook_task": NOTEBOOK_TASK,
@@ -1180,7 +1180,7 @@ def test_wait_for_termination(self, db_mock_class):
 
         op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID}
         )
         db_mock_class.assert_called_once_with(
@@ -1209,7 +1209,7 @@ def test_no_wait_for_termination(self, db_mock_class):
 
         op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID}
         )
         db_mock_class.assert_called_once_with(
@@ -1243,7 +1243,7 @@ def test_execute_task_deferred(self, db_mock_class):
         assert isinstance(exc.value.trigger, DatabricksExecutionTrigger)
         assert exc.value.method_name == "execute_complete"
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID}
         )
         db_mock_class.assert_called_once_with(
@@ -1325,7 +1325,7 @@ def test_databricks_submit_run_deferrable_operator_failed_before_defer(self, moc
         db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
         op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID}
         )
         db_mock_class.assert_called_once_with(
@@ -1355,7 +1355,7 @@ def test_databricks_submit_run_deferrable_operator_success_before_defer(self, mo
         db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
         op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID}
         )
         db_mock_class.assert_called_once_with(
@@ -1380,7 +1380,7 @@ def test_validate_json_with_named_parameters(self):
         op = DatabricksRunNowOperator(job_id=JOB_ID, task_id=TASK_ID)
         op._setup_and_validate_json()
 
-        expected = utils.normalise_json_content({"job_id": 42})
+        expected = utils._normalise_json_content({"job_id": 42})
 
         assert expected == op.json
 
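The hunk above shows the shape every updated test now follows: construct the operator, call _setup_and_validate_json() explicitly (the json dict is no longer assembled in __init__), then compare op.json against the normalised expectation. Condensed into a standalone snippet, with illustrative literal values in place of the module-level test constants and assuming this patch is applied:

    from airflow.providers.databricks.operators.databricks import DatabricksRunNowOperator
    from airflow.providers.databricks.utils import databricks as utils

    op = DatabricksRunNowOperator(task_id="run_now_task", job_id=42)
    op._setup_and_validate_json()  # json is assembled and validated here, not in __init__
    assert op.json == utils._normalise_json_content({"job_id": 42})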
@@ -1399,7 +1399,7 @@ def test_validate_json_with_json(self):
         op = DatabricksRunNowOperator(task_id=TASK_ID, json=json)
         op._setup_and_validate_json()
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": NOTEBOOK_PARAMS,
                 "jar_params": JAR_PARAMS,
@@ -1433,7 +1433,7 @@ def test_validate_json_with_merging(self):
         )
         op._setup_and_validate_json()
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": override_notebook_params,
                 "jar_params": override_jar_params,
@@ -1453,7 +1453,7 @@ def test_validate_json_with_templating(self):
         op = DatabricksRunNowOperator(dag=dag, task_id=TASK_ID, job_id=JOB_ID, json=json)
         op.render_template_fields(context={"ds": DATE})
         op._setup_and_validate_json()
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": NOTEBOOK_PARAMS,
                 "jar_params": RENDERED_TEMPLATED_JAR_PARAMS,
@@ -1514,7 +1514,7 @@ def test_validate_json_with_templated_json(self, db_mock_class, dag_maker):
         tis["push_json"].run()
         tis[TASK_ID].run()
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": NOTEBOOK_PARAMS,
                 "notebook_task": NOTEBOOK_TASK,
@@ -1561,7 +1561,7 @@ def push_json() -> dict:
         tis["push_json"].run()
         tis[TASK_ID].run()
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": NOTEBOOK_PARAMS,
                 "notebook_task": NOTEBOOK_TASK,
@@ -1594,7 +1594,7 @@ def test_exec_success(self, db_mock_class):
 
         op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": NOTEBOOK_PARAMS,
                 "notebook_task": NOTEBOOK_TASK,
@@ -1629,7 +1629,7 @@ def test_exec_failure(self, db_mock_class):
         with pytest.raises(AirflowException):
             op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": NOTEBOOK_PARAMS,
                 "notebook_task": NOTEBOOK_TASK,
@@ -1685,7 +1685,7 @@ def test_exec_failure_with_message(self, db_mock_class):
         with pytest.raises(AirflowException, match="Exception: Something went wrong"):
             op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": NOTEBOOK_PARAMS,
                 "notebook_task": NOTEBOOK_TASK,
@@ -1753,7 +1753,7 @@ def test_exec_multiple_failures_with_message(self, db_mock_class):
         ):
             op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": NOTEBOOK_PARAMS,
                 "notebook_task": NOTEBOOK_TASK,
@@ -1797,7 +1797,7 @@ def test_wait_for_termination(self, db_mock_class):
 
         op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": NOTEBOOK_PARAMS,
                 "notebook_task": NOTEBOOK_TASK,
@@ -1828,7 +1828,7 @@ def test_no_wait_for_termination(self, db_mock_class):
 
         op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": NOTEBOOK_PARAMS,
                 "notebook_task": NOTEBOOK_TASK,
@@ -1859,7 +1859,7 @@ def test_exec_with_job_name(self, db_mock_class):
 
         op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": NOTEBOOK_PARAMS,
                 "notebook_task": NOTEBOOK_TASK,
@@ -1907,7 +1907,7 @@ def test_cancel_previous_runs(self, db_mock_class):
 
         op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": NOTEBOOK_PARAMS,
                 "notebook_task": NOTEBOOK_TASK,
@@ -1941,7 +1941,7 @@ def test_no_cancel_previous_runs(self, db_mock_class):
 
         op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": NOTEBOOK_PARAMS,
                 "notebook_task": NOTEBOOK_TASK,
@@ -1978,7 +1978,7 @@ def test_execute_task_deferred(self, db_mock_class):
         assert isinstance(exc.value.trigger, DatabricksExecutionTrigger)
         assert exc.value.method_name == "execute_complete"
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": NOTEBOOK_PARAMS,
                 "notebook_task": NOTEBOOK_TASK,
@@ -2088,7 +2088,7 @@ def test_databricks_run_now_deferrable_operator_failed_before_defer(self, mock_d
         db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
         op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": NOTEBOOK_PARAMS,
                 "notebook_task": NOTEBOOK_TASK,
@@ -2122,7 +2122,7 @@ def test_databricks_run_now_deferrable_operator_success_before_defer(self, mock_
 
         op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": NOTEBOOK_PARAMS,
                 "notebook_task": NOTEBOOK_TASK,
diff --git a/tests/providers/databricks/utils/test_databricks.py b/tests/providers/databricks/utils/test_databricks.py
index 8c6ce8ce4ba59..4b57573253d47 100644
--- a/tests/providers/databricks/utils/test_databricks.py
+++ b/tests/providers/databricks/utils/test_databricks.py
@@ -21,7 +21,7 @@
 
 from airflow.exceptions import AirflowException
 from airflow.providers.databricks.hooks.databricks import RunState
-from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event
+from airflow.providers.databricks.utils.databricks import _normalise_json_content, validate_trigger_event
 
 RUN_ID = 1
 RUN_PAGE_URL = "run-page-url"
@@ -46,7 +46,7 @@ def test_normalise_json_content(self):
             "test_list": ["1", "1.0", "a", "b"],
             "test_tuple": ["1", "1.0", "a", "b"],
         }
-        assert normalise_json_content(test_json) == expected
+        assert _normalise_json_content(test_json) == expected
 
     def test_validate_trigger_event_success(self):
         event = {