diff --git a/airflow-core/src/airflow/models/mappedoperator.py b/airflow-core/src/airflow/models/mappedoperator.py index e6fb66962d1ee..5ecdf14f59051 100644 --- a/airflow-core/src/airflow/models/mappedoperator.py +++ b/airflow-core/src/airflow/models/mappedoperator.py @@ -17,11 +17,13 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING +from functools import cached_property +from typing import TYPE_CHECKING, Any import attrs import structlog +from airflow.exceptions import AirflowException from airflow.sdk.definitions.mappedoperator import MappedOperator as TaskSDKMappedOperator from airflow.triggers.base import StartTriggerArgs from airflow.utils.helpers import prevent_duplicates @@ -29,6 +31,8 @@ if TYPE_CHECKING: from sqlalchemy.orm.session import Session + from airflow.models import TaskInstance + from airflow.sdk import BaseOperatorLink from airflow.sdk.definitions.context import Context log = structlog.get_logger(__name__) @@ -118,3 +122,54 @@ def expand_start_trigger_args(self, *, context: Context, session: Session) -> St next_kwargs=next_kwargs, timeout=timeout, ) + + @cached_property + def operator_extra_link_dict(self) -> dict[str, BaseOperatorLink]: + """Returns dictionary of all extra links for the operator.""" + op_extra_links_from_plugin: dict[str, Any] = {} + from airflow import plugins_manager + + plugins_manager.initialize_extra_operators_links_plugins() + if plugins_manager.operator_extra_links is None: + raise AirflowException("Can't load operators") + operator_class_type = self.operator_class["task_type"] # type: ignore + for ope in plugins_manager.operator_extra_links: + if ope.operators and any(operator_class_type in cls.__name__ for cls in ope.operators): + op_extra_links_from_plugin.update({ope.name: ope}) + + operator_extra_links_all = {link.name: link for link in self.operator_extra_links} + # Extra links defined in Plugins overrides operator links defined in operator + operator_extra_links_all.update(op_extra_links_from_plugin) + + return operator_extra_links_all + + @cached_property + def global_operator_extra_link_dict(self) -> dict[str, Any]: + """Returns dictionary of all global extra links.""" + from airflow import plugins_manager + + plugins_manager.initialize_extra_operators_links_plugins() + if plugins_manager.global_operator_extra_links is None: + raise AirflowException("Can't load operators") + return {link.name: link for link in plugins_manager.global_operator_extra_links} + + @cached_property + def extra_links(self) -> list[str]: + return sorted(set(self.operator_extra_link_dict).union(self.global_operator_extra_link_dict)) + + def get_extra_links(self, ti: TaskInstance, name: str) -> str | None: + """ + For an operator, gets the URLs that the ``extra_links`` entry points to. + + :meta private: + + :raise ValueError: The error message of a ValueError will be passed on through to + the fronted to show up as a tooltip on the disabled link. + :param ti: The TaskInstance for the URL being searched for. + :param name: The name of the link we're looking for the URL for. Should be + one of the options specified in ``extra_links``. + """ + link = self.operator_extra_link_dict.get(name) or self.global_operator_extra_link_dict.get(name) + if not link: + return None + return link.get_link(self, ti_key=ti.key) # type: ignore[arg-type] diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py index d1d2dde78520a..6600e8e7edd6f 100644 --- a/airflow-core/tests/unit/serialization/test_dag_serialization.py +++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py @@ -91,8 +91,10 @@ from tests_common.test_utils.config import conf_vars from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker from tests_common.test_utils.mock_operators import ( + AirflowLink, AirflowLink2, CustomOperator, + GithubLink, GoogleLink, MockOperator, ) @@ -3095,6 +3097,13 @@ def operator_extra_links(self): XComOperatorLink(name="airflow", xcom_key="_link_AirflowLink2") ] + mapped_task = deserialized_dag.task_dict["task"] + assert mapped_task.operator_extra_link_dict == { + "airflow": XComOperatorLink(name="airflow", xcom_key="_link_AirflowLink2") + } + assert mapped_task.global_operator_extra_link_dict == {"airflow": AirflowLink(), "github": GithubLink()} + assert mapped_task.extra_links == sorted({"airflow", "github"}) + def test_handle_v1_serdag(): v1 = {