diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py
index a263fa9106a11..d322519230d25 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -35,10 +35,6 @@
     DatabricksWorkflowTaskGroup,
     WorkflowRunMetadata,
 )
-from airflow.providers.databricks.plugins.databricks_workflow import (
-    WorkflowJobRepairSingleTaskLink,
-    WorkflowJobRunLink,
-)
 from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger
 from airflow.providers.databricks.utils.databricks import _normalise_json_content, validate_trigger_event

@@ -962,15 +958,6 @@ def __init__(

         super().__init__(**kwargs)

-        if self._databricks_workflow_task_group is not None:
-            self.operator_extra_links = (
-                WorkflowJobRunLink(),
-                WorkflowJobRepairSingleTaskLink(),
-            )
-        else:
-            # Databricks does not support repair for non-workflow tasks, hence do not show the repair link.
-            self.operator_extra_links = (DatabricksJobRunLink(),)
-
     @cached_property
     def _hook(self) -> DatabricksHook:
         return self._get_hook(caller=self.caller)
@@ -1029,17 +1016,12 @@ def _get_run_json(self) -> dict[str, Any]:
             raise ValueError("Must specify either existing_cluster_id or new_cluster.")
         return run_json

-    def _launch_job(self, context: Context | None = None) -> int:
+    def _launch_job(self) -> int:
         """Launch the job on Databricks."""
         run_json = self._get_run_json()
         self.databricks_run_id = self._hook.submit_run(run_json)
         url = self._hook.get_run_page_url(self.databricks_run_id)
         self.log.info("Check the job run in Databricks: %s", url)
-
-        if self.do_xcom_push and context is not None:
-            context["ti"].xcom_push(key=XCOM_RUN_ID_KEY, value=self.databricks_run_id)
-            context["ti"].xcom_push(key=XCOM_RUN_PAGE_URL_KEY, value=url)
-
         return self.databricks_run_id

     def _handle_terminal_run_state(self, run_state: RunState) -> None:
@@ -1058,15 +1040,7 @@ def _get_current_databricks_task(self) -> dict[str, Any]:
         """Retrieve the Databricks task corresponding to the current Airflow task."""
         if self.databricks_run_id is None:
             raise ValueError("Databricks job not yet launched. Please run launch_notebook_job first.")
-        tasks = self._hook.get_run(self.databricks_run_id)["tasks"]
-
-        # Because the task_key remains the same across multiple runs, and the Databricks API does not return
-        # tasks sorted by their attempts/start time, we sort the tasks by start time. This ensures that we
-        # map the latest attempt (whose status is to be monitored) of the task run to the task_key while
-        # building the {task_key: task} map below.
-        sorted_task_runs = sorted(tasks, key=lambda x: x["start_time"])
-
-        return {task["task_key"]: task for task in sorted_task_runs}[
+        return {task["task_key"]: task for task in self._hook.get_run(self.databricks_run_id)["tasks"]}[
             self._get_databricks_task_id(self.task_id)
         ]

@@ -1151,7 +1125,7 @@ def execute(self, context: Context) -> None:
             self.databricks_run_id = workflow_run_metadata.run_id
             self.databricks_conn_id = workflow_run_metadata.conn_id
         else:
-            self._launch_job(context=context)
+            self._launch_job()

         if self.wait_for_termination:
             self.monitor_databricks_job()
diff --git a/airflow/providers/databricks/operators/databricks_workflow.py b/airflow/providers/databricks/operators/databricks_workflow.py
index 15333dc69118b..8203145314fd0 100644
--- a/airflow/providers/databricks/operators/databricks_workflow.py
+++ b/airflow/providers/databricks/operators/databricks_workflow.py
@@ -28,10 +28,6 @@
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunLifeCycleState
-from airflow.providers.databricks.plugins.databricks_workflow import (
-    WorkflowJobRepairAllFailedLink,
-    WorkflowJobRunLink,
-)
 from airflow.utils.task_group import TaskGroup

 if TYPE_CHECKING:
@@ -92,7 +88,6 @@ class _CreateDatabricksWorkflowOperator(BaseOperator):
         populated after instantiation using the `add_task` method.
     """

-    operator_extra_links = (WorkflowJobRunLink(), WorkflowJobRepairAllFailedLink())
     template_fields = ("notebook_params",)

     caller = "_CreateDatabricksWorkflowOperator"
diff --git a/airflow/providers/databricks/plugins/__init__.py b/airflow/providers/databricks/plugins/__init__.py
deleted file mode 100644
index 13a83393a9124..0000000000000
--- a/airflow/providers/databricks/plugins/__init__.py
+++ /dev/null
@@ -1,16 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
diff --git a/airflow/providers/databricks/plugins/databricks_workflow.py b/airflow/providers/databricks/plugins/databricks_workflow.py
deleted file mode 100644
index 186f14d02afdb..0000000000000
--- a/airflow/providers/databricks/plugins/databricks_workflow.py
+++ /dev/null
@@ -1,457 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from __future__ import annotations - -import logging -import os -from operator import itemgetter -from typing import TYPE_CHECKING, Any, cast - -from flask import current_app, flash, redirect, request, url_for -from flask_appbuilder.api import expose - -from airflow.configuration import conf -from airflow.exceptions import AirflowException, TaskInstanceNotFound -from airflow.models import BaseOperator, BaseOperatorLink -from airflow.models.dag import DAG, clear_task_instances -from airflow.models.dagrun import DagRun -from airflow.models.taskinstance import TaskInstance, TaskInstanceKey -from airflow.models.xcom import XCom -from airflow.plugins_manager import AirflowPlugin -from airflow.providers.databricks.hooks.databricks import DatabricksHook -from airflow.utils.airflow_flask_app import AirflowApp -from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.state import TaskInstanceState -from airflow.utils.task_group import TaskGroup -from airflow.www.views import AirflowBaseView - -if TYPE_CHECKING: - from sqlalchemy.orm.session import Session - - -REPAIR_WAIT_ATTEMPTS = os.getenv("DATABRICKS_REPAIR_WAIT_ATTEMPTS", 20) -REPAIR_WAIT_DELAY = os.getenv("DATABRICKS_REPAIR_WAIT_DELAY", 0.5) - -airflow_app = cast(AirflowApp, current_app) - - -def _get_databricks_task_id(task: BaseOperator) -> str: - """ - Get the databricks task ID using dag_id and task_id. removes illegal characters. - - :param task: The task to get the databricks task ID for. - :return: The databricks task ID. - """ - return f"{task.dag_id}__{task.task_id.replace('.', '__')}" - - -def get_databricks_task_ids( - group_id: str, task_map: dict[str, BaseOperator], log: logging.Logger -) -> list[str]: - """ - Return a list of all Databricks task IDs for a dictionary of Airflow tasks. - - :param group_id: The task group ID. - :param task_map: A dictionary mapping task IDs to BaseOperator instances. - :param log: The logger to use for logging. - :return: A list of Databricks task IDs for the given task group. - """ - task_ids = [] - log.debug("Getting databricks task ids for group %s", group_id) - for task_id, task in task_map.items(): - if task_id == f"{group_id}.launch": - continue - databricks_task_id = _get_databricks_task_id(task) - log.debug("databricks task id for task %s is %s", task_id, databricks_task_id) - task_ids.append(databricks_task_id) - return task_ids - - -@provide_session -def _get_dagrun(dag: DAG, run_id: str, session: Session | None = None) -> DagRun: - """ - Retrieve the DagRun object associated with the specified DAG and run_id. - - :param dag: The DAG object associated with the DagRun to retrieve. - :param run_id: The run_id associated with the DagRun to retrieve. - :param session: The SQLAlchemy session to use for the query. If None, uses the default session. - :return: The DagRun object associated with the specified DAG and run_id. 
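The deleted `_get_databricks_task_id` helper above builds the Databricks task key from the Airflow `dag_id` and `task_id`, replacing the dots that TaskGroup prefixes introduce (the docstring calls them illegal characters). A minimal standalone sketch of that convention; the function name here is ours, not the provider's:

```python
# Illustrative mirror of the removed _get_databricks_task_id helper:
# Databricks task keys are "<dag_id>__<task_id>" with "." replaced by "__".
def databricks_task_key(dag_id: str, task_id: str) -> str:
    return f"{dag_id}__{task_id.replace('.', '__')}"


# A task inside a TaskGroup gets its group prefix flattened into the key.
assert databricks_task_key("my_dag", "workflow.notebook_task") == "my_dag__workflow__notebook_task"
```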
- """ - if not session: - raise AirflowException("Session not provided.") - - return session.query(DagRun).filter(DagRun.dag_id == dag.dag_id, DagRun.run_id == run_id).first() - - -@provide_session -def _clear_task_instances( - dag_id: str, run_id: str, task_ids: list[str], log: logging.Logger, session: Session | None = None -) -> None: - dag = airflow_app.dag_bag.get_dag(dag_id) - log.debug("task_ids %s to clear", str(task_ids)) - dr: DagRun = _get_dagrun(dag, run_id, session=session) - tis_to_clear = [ti for ti in dr.get_task_instances() if _get_databricks_task_id(ti) in task_ids] - clear_task_instances(tis_to_clear, session) - - -def _repair_task( - databricks_conn_id: str, - databricks_run_id: int, - tasks_to_repair: list[str], - logger: logging.Logger, -) -> int: - """ - Repair a Databricks task using the Databricks API. - - This function allows the Airflow retry function to create a repair job for Databricks. - It uses the Databricks API to get the latest repair ID before sending the repair query. - - :param databricks_conn_id: The Databricks connection ID. - :param databricks_run_id: The Databricks run ID. - :param tasks_to_repair: A list of Databricks task IDs to repair. - :param logger: The logger to use for logging. - :return: None - """ - hook = DatabricksHook(databricks_conn_id=databricks_conn_id) - - repair_history_id = hook.get_latest_repair_id(databricks_run_id) - logger.debug("Latest repair ID is %s", repair_history_id) - logger.debug( - "Sending repair query for tasks %s on run %s", - tasks_to_repair, - databricks_run_id, - ) - - repair_json = { - "run_id": databricks_run_id, - "latest_repair_id": repair_history_id, - "rerun_tasks": tasks_to_repair, - } - - return hook.repair_run(repair_json) - - -def get_launch_task_id(task_group: TaskGroup) -> str: - """ - Retrieve the launch task ID from the current task group or a parent task group, recursively. - - :param task_group: Task Group to be inspected - :return: launch Task ID - """ - try: - launch_task_id = task_group.get_child_by_label("launch").task_id # type: ignore[attr-defined] - except KeyError as e: - if not task_group.parent_group: - raise AirflowException("No launch task can be found in the task group.") from e - launch_task_id = get_launch_task_id(task_group.parent_group) - - return launch_task_id - - -def _get_launch_task_key(current_task_key: TaskInstanceKey, task_id: str) -> TaskInstanceKey: - """ - Return the task key for the launch task. - - This allows us to gather databricks Metadata even if the current task has failed (since tasks only - create xcom values if they succeed). - - :param current_task_key: The task key for the current task. - :param task_id: The task ID for the current task. - :return: The task key for the launch task. 
- """ - if task_id: - return TaskInstanceKey( - dag_id=current_task_key.dag_id, - task_id=task_id, - run_id=current_task_key.run_id, - try_number=current_task_key.try_number, - ) - - return current_task_key - - -@provide_session -def get_task_instance(operator: BaseOperator, dttm, session: Session = NEW_SESSION) -> TaskInstance: - dag_id = operator.dag.dag_id - dag_run = DagRun.find(dag_id, execution_date=dttm)[0] - ti = ( - session.query(TaskInstance) - .filter( - TaskInstance.dag_id == dag_id, - TaskInstance.run_id == dag_run.run_id, - TaskInstance.task_id == operator.task_id, - ) - .one_or_none() - ) - if not ti: - raise TaskInstanceNotFound("Task instance not found") - return ti - - -def get_xcom_result( - ti_key: TaskInstanceKey, - key: str, - ti: TaskInstance | None, -) -> Any: - result = XCom.get_value( - ti_key=ti_key, - key=key, - ) - from airflow.providers.databricks.operators.databricks_workflow import WorkflowRunMetadata - - return WorkflowRunMetadata(**result) - - -class WorkflowJobRunLink(BaseOperatorLink, LoggingMixin): - """Constructs a link to monitor a Databricks Job Run.""" - - name = "See Databricks Job Run" - - def get_link( - self, - operator: BaseOperator, - dttm=None, - *, - ti_key: TaskInstanceKey | None = None, - ) -> str: - ti = None - if not ti_key: - ti = get_task_instance(operator, dttm) - ti_key = ti.key - task_group = operator.task_group - - if not task_group: - raise AirflowException("Task group is required for generating Databricks Workflow Job Run Link.") - - dag = airflow_app.dag_bag.get_dag(ti_key.dag_id) - dag.get_task(ti_key.task_id) - self.log.info("Getting link for task %s", ti_key.task_id) - if ".launch" not in ti_key.task_id: - self.log.debug("Finding the launch task for job run metadata %s", ti_key.task_id) - launch_task_id = get_launch_task_id(task_group) - ti_key = _get_launch_task_key(ti_key, task_id=launch_task_id) - # Should we catch the exception here if there is no return value? - metadata = get_xcom_result(ti_key, "return_value", ti) - - hook = DatabricksHook(metadata.conn_id) - return f"https://{hook.host}/#job/{metadata.job_id}/run/{metadata.run_id}" - - -class WorkflowJobRepairAllFailedLink(BaseOperatorLink, LoggingMixin): - """Constructs a link to send a request to repair all failed tasks in the Databricks workflow.""" - - name = "Repair All Failed Tasks" - - def get_link( - self, - operator, - dttm=None, - *, - ti_key: TaskInstanceKey | None = None, - ) -> str: - ti = None - if not ti_key: - ti = get_task_instance(operator, dttm) - ti_key = ti.key - task_group = operator.task_group - self.log.debug( - "Creating link to repair all tasks for databricks job run %s", - task_group.group_id, - ) - # Should we catch the exception here if there is no return value? - metadata = get_xcom_result(ti_key, "return_value", ti) - - tasks_str = self.get_tasks_to_run(ti_key, operator, self.log) - self.log.debug("tasks to rerun: %s", tasks_str) - - query_params = { - "dag_id": ti_key.dag_id, - "databricks_conn_id": metadata.conn_id, - "databricks_run_id": metadata.run_id, - "run_id": ti_key.run_id, - "tasks_to_repair": tasks_str, - } - - return url_for("RepairDatabricksTasks.repair", **query_params) - - @classmethod - def get_task_group_children(cls, task_group: TaskGroup) -> dict[str, BaseOperator]: - """ - Given a TaskGroup, return children which are Tasks, inspecting recursively any TaskGroups within. - - :param task_group: An Airflow TaskGroup - :return: Dictionary that contains Task IDs as keys and Tasks as values. 
- """ - children: dict[str, Any] = {} - for child_id, child in task_group.children.items(): - if isinstance(child, TaskGroup): - child_children = cls.get_task_group_children(child) - children = {**children, **child_children} - else: - children[child_id] = child - return children - - def get_tasks_to_run(self, ti_key: TaskInstanceKey, operator: BaseOperator, log: logging.Logger) -> str: - task_group = operator.task_group - if not task_group: - raise AirflowException("Task group is required for generating repair link.") - if not task_group.group_id: - raise AirflowException("Task group ID is required for generating repair link.") - dag = airflow_app.dag_bag.get_dag(ti_key.dag_id) - dr = _get_dagrun(dag, ti_key.run_id) - log.debug("Getting failed and skipped tasks for dag run %s", dr.run_id) - task_group_sub_tasks = self.get_task_group_children(task_group).items() - failed_and_skipped_tasks = self._get_failed_and_skipped_tasks(dr) - log.debug("Failed and skipped tasks: %s", failed_and_skipped_tasks) - - tasks_to_run = {ti: t for ti, t in task_group_sub_tasks if ti in failed_and_skipped_tasks} - - return ",".join(get_databricks_task_ids(task_group.group_id, tasks_to_run, log)) - - @staticmethod - def _get_failed_and_skipped_tasks(dr: DagRun) -> list[str]: - """ - Return a list of task IDs for tasks that have failed or have been skipped in the given DagRun. - - :param dr: The DagRun object for which to retrieve failed and skipped tasks. - - :return: A list of task IDs for tasks that have failed or have been skipped. - """ - return [ - t.task_id - for t in dr.get_task_instances( - state=[ - TaskInstanceState.FAILED, - TaskInstanceState.SKIPPED, - TaskInstanceState.UP_FOR_RETRY, - TaskInstanceState.UPSTREAM_FAILED, - None, - ], - ) - ] - - -class WorkflowJobRepairSingleTaskLink(BaseOperatorLink, LoggingMixin): - """Construct a link to send a repair request for a single databricks task.""" - - name = "Repair a single task" - - def get_link( - self, - operator, - dttm=None, - *, - ti_key: TaskInstanceKey | None = None, - ) -> str: - ti = None - if not ti_key: - ti = get_task_instance(operator, dttm) - ti_key = ti.key - - task_group = operator.task_group - if not task_group: - raise AirflowException("Task group is required for generating repair link.") - - self.log.info( - "Creating link to repair a single task for databricks job run %s task %s", - task_group.group_id, - ti_key.task_id, - ) - dag = airflow_app.dag_bag.get_dag(ti_key.dag_id) - task = dag.get_task(ti_key.task_id) - # Should we catch the exception here if there is no return value? 
- if ".launch" not in ti_key.task_id: - launch_task_id = get_launch_task_id(task_group) - ti_key = _get_launch_task_key(ti_key, task_id=launch_task_id) - metadata = get_xcom_result(ti_key, "return_value", ti) - - query_params = { - "dag_id": ti_key.dag_id, - "databricks_conn_id": metadata.conn_id, - "databricks_run_id": metadata.run_id, - "run_id": ti_key.run_id, - "tasks_to_repair": _get_databricks_task_id(task), - } - return url_for("RepairDatabricksTasks.repair", **query_params) - - -class RepairDatabricksTasks(AirflowBaseView, LoggingMixin): - """Repair databricks tasks from Airflow.""" - - default_view = "repair" - - @expose("/repair_databricks_job", methods=("GET",)) - def repair(self): - databricks_conn_id, databricks_run_id, dag_id, tasks_to_repair = itemgetter( - "databricks_conn_id", "databricks_run_id", "dag_id", "tasks_to_repair" - )(request.values) - view = conf.get("webserver", "dag_default_view") - return_url = self._get_return_url(dag_id, view) - run_id = request.values.get("run_id").replace( - " ", "+" - ) # get run id separately since we need to modify it - if not tasks_to_repair: - # If there are no tasks to repair, we return. - flash("No tasks to repair. Not sending repair request.") - return redirect(return_url) - self.log.info("Tasks to repair: %s", tasks_to_repair) - self.log.info("Repairing databricks job %s", databricks_run_id) - res = _repair_task( - databricks_conn_id=databricks_conn_id, - databricks_run_id=databricks_run_id, - tasks_to_repair=tasks_to_repair.split(","), - logger=self.log, - ) - self.log.info("Repairing databricks job query for run %s sent", databricks_run_id) - - self.log.info("Clearing tasks to rerun in airflow") - _clear_task_instances(dag_id, run_id, tasks_to_repair.split(","), self.log) - flash(f"Databricks repair job is starting!: {res}") - return redirect(return_url) - - @staticmethod - def _get_return_url(dag_id: str, view) -> str: - return f"/dags/{dag_id}/{view}" - - -repair_databricks_view = RepairDatabricksTasks() - -repair_databricks_package = { - "name": "Repair Databricks View", - "category": "Repair Databricks Plugin", - "view": repair_databricks_view, -} - - -class DatabricksWorkflowPlugin(AirflowPlugin): - """ - Databricks Workflows plugin for Airflow. - - .. 
seealso:: - For more information on how to use this plugin, take a look at the guide: - :ref:`howto/plugin:DatabricksWorkflowPlugin` - """ - - name = "databricks_workflow" - operator_extra_links = [ - WorkflowJobRepairAllFailedLink(), - WorkflowJobRepairSingleTaskLink(), - WorkflowJobRunLink(), - ] - appbuilder_views = [repair_databricks_package] diff --git a/airflow/providers/databricks/provider.yaml b/airflow/providers/databricks/provider.yaml index 0132982659884..930813ced3284 100644 --- a/airflow/providers/databricks/provider.yaml +++ b/airflow/providers/databricks/provider.yaml @@ -165,9 +165,5 @@ connection-types: - hook-class-name: airflow.providers.databricks.hooks.databricks.DatabricksHook connection-type: databricks -plugins: - - name: databricks_workflow - plugin-class: airflow.providers.databricks.plugins.databricks_workflow.DatabricksWorkflowPlugin - extra-links: - airflow.providers.databricks.operators.databricks.DatabricksJobRunLink diff --git a/docs/apache-airflow-providers-databricks/img/workflow_plugin_launch_task.png b/docs/apache-airflow-providers-databricks/img/workflow_plugin_launch_task.png deleted file mode 100644 index e99083f53ffcd..0000000000000 Binary files a/docs/apache-airflow-providers-databricks/img/workflow_plugin_launch_task.png and /dev/null differ diff --git a/docs/apache-airflow-providers-databricks/img/workflow_plugin_single_task.png b/docs/apache-airflow-providers-databricks/img/workflow_plugin_single_task.png deleted file mode 100644 index 17a130b944e5f..0000000000000 Binary files a/docs/apache-airflow-providers-databricks/img/workflow_plugin_single_task.png and /dev/null differ diff --git a/docs/apache-airflow-providers-databricks/index.rst b/docs/apache-airflow-providers-databricks/index.rst index 4e010d643e794..3358bd8bb1061 100644 --- a/docs/apache-airflow-providers-databricks/index.rst +++ b/docs/apache-airflow-providers-databricks/index.rst @@ -36,7 +36,6 @@ Connection types Operators - Plugins .. toctree:: :hidden: diff --git a/docs/apache-airflow-providers-databricks/plugins/index.rst b/docs/apache-airflow-providers-databricks/plugins/index.rst deleted file mode 100644 index 5ddb65f6f3b3d..0000000000000 --- a/docs/apache-airflow-providers-databricks/plugins/index.rst +++ /dev/null @@ -1,28 +0,0 @@ - .. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - - .. http://www.apache.org/licenses/LICENSE-2.0 - - .. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - - - -Databricks Plugins -================== - - -.. toctree:: - :maxdepth: 1 - :glob: - - * diff --git a/docs/apache-airflow-providers-databricks/plugins/workflow.rst b/docs/apache-airflow-providers-databricks/plugins/workflow.rst deleted file mode 100644 index 22acd05596791..0000000000000 --- a/docs/apache-airflow-providers-databricks/plugins/workflow.rst +++ /dev/null @@ -1,60 +0,0 @@ - .. 
Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - - .. http://www.apache.org/licenses/LICENSE-2.0 - - .. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - -.. _howto/plugin:DatabricksWorkflowPlugin: - - -DatabricksWorkflowPlugin -======================== - - -Overview --------- - -The ``DatabricksWorkflowPlugin`` enhances the Airflow UI by providing links for tasks that allow users to view the -Databricks job run in the Databricks workspace. Additionally, it offers links to repair task(s) within the workflow. - -Features --------- - -- **Task-Level Links**: Within the workflow, each task includes links to the job run and a repair link for the individual task. - -- **Workflow-Level Links**: At the workflow level, for the job launch task, the plugin provides a link to repair all failed tasks and a link to the job run(allows users to monitor the job in the Databricks account) in the Databricks workspace. - -Examples --------- - -- **Job Run Link and Repair link for Single Task**: - -.. image:: ../img/workflow_plugin_single_task.png - -- **Workflow-Level Links to the job run and to repair all failed tasks**: - -.. image:: ../img/workflow_plugin_launch_task.png - -Notes ------ - -Databricks does not allow repairing jobs with single tasks launched outside the workflow. Hence, for these tasks, only the job run link is provided. - -Usage ------ - -Ideally, installing the provider will also install the plugin, and it should work automatically in your deployment. -However, if custom configurations are preventing the use of plugins, ensure the plugin is properly installed and -configured in your Airflow environment to utilize its features. The plugin will automatically detect Databricks jobs, -as the links are embedded in the relevant operators. 
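The documentation removed above describes links that were wired up through the standard Airflow plugin mechanism: `BaseOperatorLink` subclasses plus a Flask-AppBuilder view registered via an `AirflowPlugin`. A condensed, illustrative sketch of that pattern; the class and link names are hypothetical stand-ins and the `get_link` signature is the simplified modern form, not the provider's API after this change:

```python
from airflow.models import BaseOperator, BaseOperatorLink
from airflow.models.taskinstance import TaskInstanceKey
from airflow.plugins_manager import AirflowPlugin


class ExampleJobRunLink(BaseOperatorLink):
    """Stand-in for the removed WorkflowJobRunLink."""

    name = "See Example Job Run"

    def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str:
        # The real link resolved WorkflowRunMetadata from the launch task's XCom;
        # a static URL stands in here.
        return "https://example.cloud.databricks.com/#job/0/run/0"


class ExampleDatabricksWorkflowPlugin(AirflowPlugin):
    name = "example_databricks_workflow"
    operator_extra_links = [ExampleJobRunLink()]
    appbuilder_views = []  # the removed plugin also registered the RepairDatabricksTasks view here
```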
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index c9de823c0035a..9d025228867ae 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -435,12 +435,7 @@ "devel-deps": [ "deltalake>=0.12.0" ], - "plugins": [ - { - "name": "databricks_workflow", - "plugin-class": "airflow.providers.databricks.plugins.databricks_workflow.DatabricksWorkflowPlugin" - } - ], + "plugins": [], "cross-providers-deps": [ "common.sql" ], diff --git a/tests/plugins/test_plugins_manager.py b/tests/plugins/test_plugins_manager.py index a8962b1a61797..b60a9de17999d 100644 --- a/tests/plugins/test_plugins_manager.py +++ b/tests/plugins/test_plugins_manager.py @@ -414,7 +414,7 @@ def test_does_not_double_import_entrypoint_provider_plugins(self): assert len(plugins_manager.plugins) == 0 plugins_manager.load_entrypoint_plugins() plugins_manager.load_providers_plugins() - assert len(plugins_manager.plugins) == 3 + assert len(plugins_manager.plugins) == 2 class TestPluginsDirectorySource: diff --git a/tests/providers/databricks/plugins/__init__.py b/tests/providers/databricks/plugins/__init__.py deleted file mode 100644 index 13a83393a9124..0000000000000 --- a/tests/providers/databricks/plugins/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. diff --git a/tests/providers/databricks/plugins/test_databricks_workflow.py b/tests/providers/databricks/plugins/test_databricks_workflow.py deleted file mode 100644 index 550b28bc196b3..0000000000000 --- a/tests/providers/databricks/plugins/test_databricks_workflow.py +++ /dev/null @@ -1,261 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -from __future__ import annotations - -from unittest.mock import MagicMock, Mock, patch - -import pytest -from flask import request - -from airflow.exceptions import AirflowException -from airflow.models.dagrun import DagRun -from airflow.models.taskinstance import TaskInstanceKey -from airflow.plugins_manager import AirflowPlugin -from airflow.providers.databricks.plugins.databricks_workflow import ( - DatabricksWorkflowPlugin, - RepairDatabricksTasks, - WorkflowJobRepairSingleTaskLink, - WorkflowJobRunLink, - _get_dagrun, - _get_databricks_task_id, - _get_launch_task_key, - _repair_task, - get_databricks_task_ids, - get_launch_task_id, - get_task_instance, -) -from airflow.utils.dates import days_ago -from airflow.www.app import create_app - -DAG_ID = "test_dag" -TASK_ID = "test_task" -RUN_ID = "test_run_1" -DAG_RUN_DATE = days_ago(1) -TASK_INSTANCE_KEY = TaskInstanceKey(dag_id=DAG_ID, task_id=TASK_ID, run_id=RUN_ID, try_number=1) -DATABRICKS_CONN_ID = "databricks_default" -DATABRICKS_RUN_ID = 12345 -GROUP_ID = "test_group" -TASK_MAP = { - "task1": MagicMock(dag_id=DAG_ID, task_id="task1"), - "task2": MagicMock(dag_id=DAG_ID, task_id="task2"), -} -LOG = MagicMock() - - -@pytest.mark.parametrize( - "task, expected_id", - [ - (MagicMock(dag_id="dag1", task_id="task.1"), "dag1__task__1"), - (MagicMock(dag_id="dag2", task_id="task_1"), "dag2__task_1"), - ], -) -def test_get_databricks_task_id(task, expected_id): - result = _get_databricks_task_id(task) - - assert result == expected_id - - -def test_get_databricks_task_ids(): - result = get_databricks_task_ids(GROUP_ID, TASK_MAP, LOG) - - expected_ids = ["test_dag__task1", "test_dag__task2"] - assert result == expected_ids - - -def test_get_dagrun(): - session = MagicMock() - dag = MagicMock(dag_id=DAG_ID) - session.query.return_value.filter.return_value.first.return_value = DagRun() - - result = _get_dagrun(dag, RUN_ID, session=session) - - assert isinstance(result, DagRun) - - -@patch("airflow.providers.databricks.plugins.databricks_workflow.DatabricksHook") -def test_repair_task(mock_databricks_hook): - mock_hook_instance = mock_databricks_hook.return_value - mock_hook_instance.get_latest_repair_id.return_value = 100 - mock_hook_instance.repair_run.return_value = 200 - - tasks_to_repair = ["task1", "task2"] - result = _repair_task(DATABRICKS_CONN_ID, DATABRICKS_RUN_ID, tasks_to_repair, LOG) - - assert result == 200 - mock_hook_instance.get_latest_repair_id.assert_called_once_with(DATABRICKS_RUN_ID) - mock_hook_instance.repair_run.assert_called_once() - - -def test_get_launch_task_id_no_launch_task(): - task_group = MagicMock(get_child_by_label=MagicMock(side_effect=KeyError)) - task_group.parent_group = None - - with pytest.raises(AirflowException): - get_launch_task_id(task_group) - - -def test_get_launch_task_key(): - result = _get_launch_task_key(TASK_INSTANCE_KEY, "launch_task") - - assert isinstance(result, TaskInstanceKey) - assert result.dag_id == TASK_INSTANCE_KEY.dag_id - assert result.task_id == "launch_task" - assert result.run_id == TASK_INSTANCE_KEY.run_id - - -@pytest.fixture(scope="session") -def app(): - app = create_app(testing=True) - app.config["SERVER_NAME"] = "localhost" - - with app.app_context(): - yield app - - -def test_repair_databricks_tasks(app): - with app.test_request_context("/"): - view = RepairDatabricksTasks() - request_values = { - "databricks_conn_id": "conn_id", - "databricks_run_id": "run_id", - "run_id": "run_id", - "dag_id": "dag_id", - "tasks_to_repair": "task1,task2", - } - - with patch( - 
"airflow.providers.databricks.plugins.databricks_workflow._repair_task" - ) as mock_repair_task, patch( - "airflow.providers.databricks.plugins.databricks_workflow._clear_task_instances" - ) as mock_clear_task_instances, patch( - "airflow.providers.databricks.plugins.databricks_workflow.flash" - ) as mock_flash, patch( - "airflow.providers.databricks.plugins.databricks_workflow.redirect" - ) as mock_redirect: - request.values = request_values - - _ = view.repair() - - mock_repair_task.assert_called_once() - mock_clear_task_instances.assert_called_once() - mock_flash.assert_called_once() - mock_redirect.assert_called_once() - - -def test_get_task_instance(app): - with app.app_context(): - operator = Mock() - operator.dag.dag_id = "dag_id" - operator.task_id = "task_id" - dttm = "2022-01-01T00:00:00Z" - session = Mock() - dag_run = Mock() - session.query().filter().one_or_none.return_value = dag_run - - with patch( - "airflow.providers.databricks.plugins.databricks_workflow.DagRun.find", return_value=[dag_run] - ): - result = get_task_instance(operator, dttm, session) - assert result == dag_run - - -def test_workflow_job_run_link(app): - with app.app_context(): - link = WorkflowJobRunLink() - operator = Mock() - ti_key = Mock() - ti_key.dag_id = "dag_id" - ti_key.task_id = "task_id" - ti_key.run_id = "run_id" - ti_key.try_number = 1 - - with patch( - "airflow.providers.databricks.plugins.databricks_workflow.get_task_instance" - ) as mock_get_task_instance: - with patch( - "airflow.providers.databricks.plugins.databricks_workflow.get_xcom_result" - ) as mock_get_xcom_result: - with patch( - "airflow.providers.databricks.plugins.databricks_workflow.airflow_app.dag_bag.get_dag" - ) as mock_get_dag: - mock_connection = Mock() - mock_connection.extra_dejson = {"host": "mockhost"} - - with patch( - "airflow.providers.databricks.hooks.databricks.DatabricksHook.get_connection", - return_value=mock_connection, - ): - mock_get_task_instance.return_value = Mock(key=ti_key) - mock_get_xcom_result.return_value = Mock(conn_id="conn_id", run_id=1, job_id=1) - mock_get_dag.return_value.get_task = Mock(return_value=Mock(task_id="task_id")) - - result = link.get_link(operator, ti_key=ti_key) - assert "https://mockhost/#job/1/run/1" in result - - -def test_workflow_job_repair_single_failed_link(app): - with app.app_context(): - link = WorkflowJobRepairSingleTaskLink() - operator = Mock() - operator.task_group = Mock() - operator.task_group.group_id = "group_id" - operator.task_group.get_child_by_label = Mock() - ti_key = Mock() - ti_key.dag_id = "dag_id" - ti_key.task_id = "task_id" - ti_key.run_id = "run_id" - ti_key.try_number = 1 - - with patch( - "airflow.providers.databricks.plugins.databricks_workflow.get_task_instance" - ) as mock_get_task_instance: - with patch( - "airflow.providers.databricks.plugins.databricks_workflow.get_xcom_result" - ) as mock_get_xcom_result: - with patch( - "airflow.providers.databricks.plugins.databricks_workflow.airflow_app.dag_bag.get_dag" - ) as mock_get_dag: - mock_get_task_instance.return_value = Mock(key=ti_key) - mock_get_xcom_result.return_value = Mock(conn_id="conn_id", run_id=1) - mock_get_dag.return_value.get_task = Mock(return_value=Mock(task_id="task_id")) - - result = link.get_link(operator, ti_key=ti_key) - assert result.startswith("http://localhost/repair_databricks_job") - - -@pytest.fixture -def plugin(): - return DatabricksWorkflowPlugin() - - -def test_plugin_is_airflow_plugin(plugin): - assert isinstance(plugin, AirflowPlugin) - - -def 
test_operator_extra_links(plugin): - for link in plugin.operator_extra_links: - assert hasattr(link, "get_link") - - -def test_appbuilder_views(plugin): - assert plugin.appbuilder_views is not None - assert len(plugin.appbuilder_views) == 1 - - repair_view = plugin.appbuilder_views[0]["view"] - assert isinstance(repair_view, RepairDatabricksTasks) - assert repair_view.default_view == "repair"
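The deleted suite above only covered the plugin. The behavioral change that remains in the operators is the reworked lookup in `_get_current_databricks_task`, which now keys the run's task list by `task_key` and selects the current task by its Databricks ID. A hedged sketch (not part of the removed suite) of that lookup against mocked run data:

```python
from unittest.mock import MagicMock


def test_current_task_selected_by_task_key():
    # Fake hook response shaped like the Databricks get_run payload used in the operator.
    hook = MagicMock()
    hook.get_run.return_value = {
        "tasks": [
            {"task_key": "my_dag__other_task", "state": "TERMINATED"},
            {"task_key": "my_dag__notebook_task", "state": "RUNNING"},
        ]
    }

    # Same dict-comprehension lookup as the new _get_current_databricks_task body.
    tasks = hook.get_run(12345)["tasks"]
    current = {task["task_key"]: task for task in tasks}["my_dag__notebook_task"]

    assert current["state"] == "RUNNING"
```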