diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py index 1f1ab896a0973..47ef247fac6c2 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py @@ -778,11 +778,13 @@ def _get_bool(val) -> bool | None: class AsyncKubernetesHook(KubernetesHook): """Hook to use Kubernetes SDK asynchronously.""" - def __init__(self, config_dict: dict | None = None, *args, **kwargs): + def __init__( + self, config_dict: dict | None = None, connection_extras: dict | None = None, *args, **kwargs + ): super().__init__(*args, **kwargs) self.config_dict = config_dict - self._extras: dict | None = None + self._extras: dict | None = connection_extras self._event_polling_fallback = False async def _load_config(self): diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py index 3337fc7749071..9be27560a088e 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py @@ -81,10 +81,12 @@ from airflow.providers.common.compat.sdk import XCOM_RETURN_KEY, AirflowSkipException, TaskDeferred if AIRFLOW_V_3_1_PLUS: - from airflow.sdk import BaseOperator + from airflow.sdk import BaseHook, BaseOperator else: + from airflow.hooks.base import BaseHook # type: ignore[attr-defined, no-redef] from airflow.models import BaseOperator -from airflow.providers.common.compat.sdk import AirflowException + +from airflow.providers.common.compat.sdk import AirflowException, AirflowNotFoundException from airflow.settings import pod_mutation_hook from airflow.utils import yaml from airflow.utils.helpers import prune_dict, validate_key @@ -868,6 +870,21 @@ def convert_config_file_to_dict(self): def invoke_defer_method(self, last_log_time: DateTime | None = None) -> None: """Redefine triggers which are being used in child classes.""" self.convert_config_file_to_dict() + + connection_extras = None + if self.kubernetes_conn_id: + try: + conn = BaseHook.get_connection(self.kubernetes_conn_id) + except AirflowNotFoundException: + self.log.warning( + "Could not resolve connection extras for deferral: connection `%s` not found. " + "Triggerer will try to resolve it from its own environment.", + self.kubernetes_conn_id, + ) + else: + connection_extras = conn.extra_dejson + self.log.info("Successfully resolved connection extras for deferral.") + trigger_start_time = datetime.datetime.now(tz=datetime.timezone.utc) self.defer( trigger=KubernetesPodTrigger( @@ -875,6 +892,7 @@ def invoke_defer_method(self, last_log_time: DateTime | None = None) -> None: pod_namespace=self.pod.metadata.namespace, # type: ignore[union-attr] trigger_start_time=trigger_start_time, kubernetes_conn_id=self.kubernetes_conn_id, + connection_extras=connection_extras, cluster_context=self.cluster_context, config_dict=self._config_dict, in_cluster=self.in_cluster, diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py index 5aa496a772aa4..d5c8c59520c30 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py @@ -88,6 +88,7 @@ def __init__( trigger_start_time: datetime.datetime, base_container_name: str, kubernetes_conn_id: str | None = None, + connection_extras: dict | None = None, poll_interval: float = 2, cluster_context: str | None = None, config_dict: dict | None = None, @@ -107,6 +108,7 @@ def __init__( self.trigger_start_time = trigger_start_time self.base_container_name = base_container_name self.kubernetes_conn_id = kubernetes_conn_id + self.connection_extras = connection_extras self.poll_interval = poll_interval self.cluster_context = cluster_context self.config_dict = config_dict @@ -130,6 +132,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "pod_namespace": self.pod_namespace, "base_container_name": self.base_container_name, "kubernetes_conn_id": self.kubernetes_conn_id, + "connection_extras": self.connection_extras, "poll_interval": self.poll_interval, "cluster_context": self.cluster_context, "config_dict": self.config_dict, @@ -324,6 +327,7 @@ def hook(self) -> AsyncKubernetesHook: in_cluster=self.in_cluster, config_dict=self.config_dict, cluster_context=self.cluster_context, + connection_extras=self.connection_extras, ) @cached_property diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py index 0a16c9920fe73..67464dde10180 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py @@ -2342,8 +2342,16 @@ def run_pod_async(self, operator: KubernetesPodOperator, map_index: int = -1): @patch(KUB_OP_PATH.format("find_pod")) @patch(KUB_OP_PATH.format("build_pod_request_obj")) @patch(KUB_OP_PATH.format("get_or_create_pod")) + @patch("airflow.providers.cncf.kubernetes.operators.pod.BaseHook.get_connection") def test_async_create_pod_should_execute_successfully( - self, mocked_pod, mocked_pod_obj, mocked_found_pod, mocked_client, do_xcom_push, mocker + self, + mocked_get_connection, + mocked_pod, + mocked_pod_obj, + mocked_found_pod, + mocked_client, + do_xcom_push, + mocker, ): """ Asserts that a task is deferred and the KubernetesCreatePodTrigger will be fired @@ -2352,6 +2360,8 @@ def test_async_create_pod_should_execute_successfully( pod name and namespace are *always* pushed; do_xcom_push only controls xcom sidecar """ + mocked_get_connection.return_value.extra_dejson = {"foo": "bar"} + k = KubernetesPodOperator( task_id=TEST_TASK_ID, namespace=TEST_NAMESPACE, @@ -2384,6 +2394,8 @@ def test_async_create_pod_should_execute_successfully( ti_mock.xcom_push.assert_any_call(key="pod_name", value=TEST_NAME) ti_mock.xcom_push.assert_any_call(key="pod_namespace", value=TEST_NAMESPACE) assert isinstance(exc.value.trigger, KubernetesPodTrigger) + assert exc.value.trigger.connection_extras == {"foo": "bar"} + mocked_get_connection.assert_called_once_with(k.kubernetes_conn_id) @pytest.mark.parametrize("status", ["error", "failed", "timeout"]) @patch(KUB_OP_PATH.format("log")) diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_pod.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_pod.py index f2ce3b00dc3eb..98fb9a79fb8b0 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_pod.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_pod.py @@ -114,6 +114,7 @@ def test_serialize(self, trigger): "pod_namespace": NAMESPACE, "base_container_name": BASE_CONTAINER_NAME, "kubernetes_conn_id": CONN_ID, + "connection_extras": None, "poll_interval": POLL_INTERVAL, "cluster_context": CLUSTER_CONTEXT, "config_dict": CONFIG_DICT, @@ -129,6 +130,52 @@ def test_serialize(self, trigger): "trigger_kwargs": {}, } + def test_serialize_with_connection_extras(self): + extras = {"token": "abc"} + trigger = KubernetesPodTrigger( + pod_name=POD_NAME, + pod_namespace=NAMESPACE, + base_container_name=BASE_CONTAINER_NAME, + kubernetes_conn_id=CONN_ID, + connection_extras=extras, + poll_interval=POLL_INTERVAL, + cluster_context=CLUSTER_CONTEXT, + config_dict=CONFIG_DICT, + in_cluster=IN_CLUSTER, + get_logs=GET_LOGS, + startup_timeout=STARTUP_TIMEOUT_SECS, + startup_check_interval=STARTUP_CHECK_INTERVAL_SECS, + schedule_timeout=STARTUP_TIMEOUT_SECS, + trigger_start_time=TRIGGER_START_TIME, + on_finish_action=ON_FINISH_ACTION, + ) + + _, kwargs_dict = trigger.serialize() + + assert kwargs_dict["connection_extras"] == extras + + def test_hook_uses_provided_connection_extras(self): + extras = {"token": "abc"} + trigger = KubernetesPodTrigger( + pod_name=POD_NAME, + pod_namespace=NAMESPACE, + base_container_name=BASE_CONTAINER_NAME, + kubernetes_conn_id=CONN_ID, + connection_extras=extras, + poll_interval=POLL_INTERVAL, + cluster_context=CLUSTER_CONTEXT, + config_dict=CONFIG_DICT, + in_cluster=IN_CLUSTER, + get_logs=GET_LOGS, + startup_timeout=STARTUP_TIMEOUT_SECS, + startup_check_interval=STARTUP_CHECK_INTERVAL_SECS, + schedule_timeout=STARTUP_TIMEOUT_SECS, + trigger_start_time=TRIGGER_START_TIME, + on_finish_action=ON_FINISH_ACTION, + ) + + assert trigger.hook._extras == extras + @pytest.mark.asyncio @mock.patch(f"{TRIGGER_PATH}._wait_for_pod_start") async def test_run_loop_return_success_event(self, mock_wait_pod, trigger):