diff --git a/airflow/providers/databricks/plugins/databricks_workflow.py b/airflow/providers/databricks/plugins/databricks_workflow.py index 41c7b6735759..be2eaecfe3ba 100644 --- a/airflow/providers/databricks/plugins/databricks_workflow.py +++ b/airflow/providers/databricks/plugins/databricks_workflow.py @@ -26,7 +26,6 @@ from flask_appbuilder.api import expose from packaging.version import Version -from airflow.configuration import conf from airflow.exceptions import AirflowException, TaskInstanceNotFound from airflow.models import BaseOperator, BaseOperatorLink from airflow.models.dag import DAG, clear_task_instances @@ -413,8 +412,7 @@ class RepairDatabricksTasks(AirflowBaseView, LoggingMixin): @expose("/repair_databricks_job//", methods=("GET",)) @get_auth_decorator() def repair(self, dag_id: str, run_id: str): - view = conf.get("webserver", "dag_default_view") - return_url = self._get_return_url(dag_id, view) + return_url = self._get_return_url(dag_id, run_id) tasks_to_repair = request.values.get("tasks_to_repair") self.log.info("Tasks to repair: %s", tasks_to_repair) @@ -450,8 +448,8 @@ def repair(self, dag_id: str, run_id: str): return redirect(return_url) @staticmethod - def _get_return_url(dag_id: str, view) -> str: - return f"/dags/{dag_id}/{view}" + def _get_return_url(dag_id: str, run_id: str) -> str: + return url_for("Airflow.grid", dag_id=dag_id, dag_run_id=run_id) repair_databricks_view = RepairDatabricksTasks() diff --git a/tests/providers/databricks/plugins/test_databricks_workflow.py b/tests/providers/databricks/plugins/test_databricks_workflow.py index ad14be2bd568..c90caaceee86 100644 --- a/tests/providers/databricks/plugins/test_databricks_workflow.py +++ b/tests/providers/databricks/plugins/test_databricks_workflow.py @@ -20,6 +20,7 @@ from unittest.mock import MagicMock, Mock, patch import pytest +from flask import url_for from airflow.exceptions import AirflowException from airflow.models.dagrun import DagRun @@ -144,6 +145,18 @@ def test_get_task_instance(app): assert result == dag_run +@pytest.mark.db_test +def test_get_return_url_dag_id_run_id(app): + dag_id = "example_dag" + run_id = "example_run" + + expected_url = url_for("Airflow.grid", dag_id=dag_id, dag_run_id=run_id) + + with app.app_context(): + actual_url = RepairDatabricksTasks._get_return_url(dag_id, run_id) + assert actual_url == expected_url, f"Expected {expected_url}, got {actual_url}" + + @pytest.mark.db_test def test_workflow_job_run_link(app): with app.app_context():