diff --git a/airflow-core/src/airflow/configuration.py b/airflow-core/src/airflow/configuration.py index cd16658e929f0..d8ad81274679b 100644 --- a/airflow-core/src/airflow/configuration.py +++ b/airflow-core/src/airflow/configuration.py @@ -50,8 +50,8 @@ from airflow.utils.module_loading import import_string if TYPE_CHECKING: + from airflow._shared.secrets_backend.base import BaseSecretsBackend from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager - from airflow.secrets import BaseSecretsBackend log = logging.getLogger(__name__) @@ -851,6 +851,8 @@ def initialize_secrets_backends( * import secrets backend classes * instantiate them and return them in a list """ + from airflow.models.connection import Connection + backend_list = [] worker_mode = False if default_backends != DEFAULT_SECRETS_SEARCH_PATH: @@ -863,6 +865,9 @@ def initialize_secrets_backends( for class_name in default_backends: secrets_backend_cls = import_string(class_name) + if not hasattr(secrets_backend_cls, "set_connection_class"): + raise ValueError(f"{secrets_backend_cls} does not have set_connection_class method") + secrets_backend_cls.set_connection_class(Connection) backend_list.append(secrets_backend_cls()) return backend_list diff --git a/shared/secrets_backend/src/airflow_shared/secrets_backend/base.py b/shared/secrets_backend/src/airflow_shared/secrets_backend/base.py index 0ae361ffd3267..af181f20982bd 100644 --- a/shared/secrets_backend/src/airflow_shared/secrets_backend/base.py +++ b/shared/secrets_backend/src/airflow_shared/secrets_backend/base.py @@ -22,6 +22,17 @@ class BaseSecretsBackend(ABC): """Abstract base class to retrieve Connection object given a conn_id or Variable given a key.""" + _connection_class = None + + @classmethod + def set_connection_class(cls, connection_class: type) -> None: + """ + Set the Connection class to use for deserialization. + + :param connection_class: The Connection class to use. + """ + cls._connection_class = connection_class + @staticmethod def build_path(path_prefix: str, secret_id: str, sep: str = "/") -> str: """ @@ -62,28 +73,11 @@ def get_config(self, key: str) -> str | None: """ return None - @staticmethod - def _get_connection_class(): - """ - Detect which Connection class to use based on execution context. - - Returns SDK Connection in worker context, core Connection in server context. - """ - import os - - process_context = os.environ.get("_AIRFLOW_PROCESS_CONTEXT", "").lower() - if process_context == "client": - # Client context (worker, dag processor, triggerer) - from airflow.sdk.definitions.connection import Connection - - return Connection - - # Server context (scheduler, API server, etc.) - from airflow.models.connection import Connection - - return Connection - - def deserialize_connection(self, conn_id: str, value: str): + def deserialize_connection( + self, + conn_id: str, + value: str, + ): """ Given a serialized representation of the airflow Connection, return an instance. @@ -94,16 +88,23 @@ def deserialize_connection(self, conn_id: str, value: str): :param value: the serialized representation of the Connection object :return: the deserialized Connection """ - conn_class = self._get_connection_class() - + if not self._connection_class: + raise ValueError( + "Connection class is not set. You must call `set_connection_class` on the class " + "before calling deserialize_connection." + ) value = value.strip() if value[0] == "{": - return conn_class.from_json(value=value, conn_id=conn_id) + if hasattr(self._connection_class, "from_json"): + return self._connection_class.from_json(value=value, conn_id=conn_id) + raise ValueError( + "Connection class does not support from_json deserialization: {self._connection_class}" + ) # TODO: Only sdk has from_uri defined on it. Is it worthwhile developing the core path or not? - if hasattr(conn_class, "from_uri"): - return conn_class.from_uri(conn_id=conn_id, uri=value) - return conn_class(conn_id=conn_id, uri=value) + if hasattr(self._connection_class, "from_uri"): + return self._connection_class.from_uri(conn_id=conn_id, uri=value) + return self._connection_class(conn_id=conn_id, uri=value) def get_connection(self, conn_id: str): """ diff --git a/task-sdk/src/airflow/sdk/configuration.py b/task-sdk/src/airflow/sdk/configuration.py index e60df4e13b734..b3fc00c406037 100644 --- a/task-sdk/src/airflow/sdk/configuration.py +++ b/task-sdk/src/airflow/sdk/configuration.py @@ -195,6 +195,7 @@ def initialize_secrets_backends( Uses SDK's conf instead of Core's conf. """ + from airflow.sdk.definitions.connection import Connection from airflow.sdk.module_loading import import_string backend_list = [] @@ -214,6 +215,9 @@ def initialize_secrets_backends( for class_name in default_backends: secrets_backend_cls = import_string(class_name) + if not hasattr(secrets_backend_cls, "set_connection_class"): + raise ValueError(f"{secrets_backend_cls} does not have set_connection_class method") + secrets_backend_cls.set_connection_class(Connection) backend_list.append(secrets_backend_cls()) return backend_list