Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -778,11 +778,13 @@ def _get_bool(val) -> bool | None:
class AsyncKubernetesHook(KubernetesHook):
"""Hook to use Kubernetes SDK asynchronously."""

def __init__(self, config_dict: dict | None = None, *args, **kwargs):
def __init__(
self, config_dict: dict | None = None, connection_extras: dict | None = None, *args, **kwargs
):
super().__init__(*args, **kwargs)

self.config_dict = config_dict
self._extras: dict | None = None
self._extras: dict | None = connection_extras
self._event_polling_fallback = False

async def _load_config(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,12 @@
from airflow.providers.common.compat.sdk import XCOM_RETURN_KEY, AirflowSkipException, TaskDeferred

if AIRFLOW_V_3_1_PLUS:
from airflow.sdk import BaseOperator
from airflow.sdk import BaseHook, BaseOperator
else:
from airflow.hooks.base import BaseHook # type: ignore[attr-defined, no-redef]
from airflow.models import BaseOperator
from airflow.providers.common.compat.sdk import AirflowException

from airflow.providers.common.compat.sdk import AirflowException, AirflowNotFoundException
from airflow.settings import pod_mutation_hook
from airflow.utils import yaml
from airflow.utils.helpers import prune_dict, validate_key
Expand Down Expand Up @@ -868,13 +870,29 @@ def convert_config_file_to_dict(self):
def invoke_defer_method(self, last_log_time: DateTime | None = None) -> None:
"""Redefine triggers which are being used in child classes."""
self.convert_config_file_to_dict()

connection_extras = None
if self.kubernetes_conn_id:
try:
conn = BaseHook.get_connection(self.kubernetes_conn_id)
except AirflowNotFoundException:
self.log.warning(
"Could not resolve connection extras for deferral: connection `%s` not found. "
"Triggerer will try to resolve it from its own environment.",
self.kubernetes_conn_id,
)
else:
connection_extras = conn.extra_dejson
self.log.info("Successfully resolved connection extras for deferral.")

trigger_start_time = datetime.datetime.now(tz=datetime.timezone.utc)
self.defer(
trigger=KubernetesPodTrigger(
pod_name=self.pod.metadata.name, # type: ignore[union-attr]
pod_namespace=self.pod.metadata.namespace, # type: ignore[union-attr]
trigger_start_time=trigger_start_time,
kubernetes_conn_id=self.kubernetes_conn_id,
connection_extras=connection_extras,
cluster_context=self.cluster_context,
config_dict=self._config_dict,
in_cluster=self.in_cluster,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __init__(
trigger_start_time: datetime.datetime,
base_container_name: str,
kubernetes_conn_id: str | None = None,
connection_extras: dict | None = None,
poll_interval: float = 2,
cluster_context: str | None = None,
config_dict: dict | None = None,
Expand All @@ -107,6 +108,7 @@ def __init__(
self.trigger_start_time = trigger_start_time
self.base_container_name = base_container_name
self.kubernetes_conn_id = kubernetes_conn_id
self.connection_extras = connection_extras
self.poll_interval = poll_interval
self.cluster_context = cluster_context
self.config_dict = config_dict
Expand All @@ -130,6 +132,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"pod_namespace": self.pod_namespace,
"base_container_name": self.base_container_name,
"kubernetes_conn_id": self.kubernetes_conn_id,
"connection_extras": self.connection_extras,
"poll_interval": self.poll_interval,
"cluster_context": self.cluster_context,
"config_dict": self.config_dict,
Expand Down Expand Up @@ -324,6 +327,7 @@ def hook(self) -> AsyncKubernetesHook:
in_cluster=self.in_cluster,
config_dict=self.config_dict,
cluster_context=self.cluster_context,
connection_extras=self.connection_extras,
)

@cached_property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2342,8 +2342,16 @@ def run_pod_async(self, operator: KubernetesPodOperator, map_index: int = -1):
@patch(KUB_OP_PATH.format("find_pod"))
@patch(KUB_OP_PATH.format("build_pod_request_obj"))
@patch(KUB_OP_PATH.format("get_or_create_pod"))
@patch("airflow.providers.cncf.kubernetes.operators.pod.BaseHook.get_connection")
def test_async_create_pod_should_execute_successfully(
self, mocked_pod, mocked_pod_obj, mocked_found_pod, mocked_client, do_xcom_push, mocker
self,
mocked_get_connection,
mocked_pod,
mocked_pod_obj,
mocked_found_pod,
mocked_client,
do_xcom_push,
mocker,
):
"""
Asserts that a task is deferred and the KubernetesCreatePodTrigger will be fired
Expand All @@ -2352,6 +2360,8 @@ def test_async_create_pod_should_execute_successfully(
pod name and namespace are *always* pushed; do_xcom_push only controls xcom sidecar
"""

mocked_get_connection.return_value.extra_dejson = {"foo": "bar"}

k = KubernetesPodOperator(
task_id=TEST_TASK_ID,
namespace=TEST_NAMESPACE,
Expand Down Expand Up @@ -2384,6 +2394,8 @@ def test_async_create_pod_should_execute_successfully(
ti_mock.xcom_push.assert_any_call(key="pod_name", value=TEST_NAME)
ti_mock.xcom_push.assert_any_call(key="pod_namespace", value=TEST_NAMESPACE)
assert isinstance(exc.value.trigger, KubernetesPodTrigger)
assert exc.value.trigger.connection_extras == {"foo": "bar"}
mocked_get_connection.assert_called_once_with(k.kubernetes_conn_id)

@pytest.mark.parametrize("status", ["error", "failed", "timeout"])
@patch(KUB_OP_PATH.format("log"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def test_serialize(self, trigger):
"pod_namespace": NAMESPACE,
"base_container_name": BASE_CONTAINER_NAME,
"kubernetes_conn_id": CONN_ID,
"connection_extras": None,
"poll_interval": POLL_INTERVAL,
"cluster_context": CLUSTER_CONTEXT,
"config_dict": CONFIG_DICT,
Expand All @@ -129,6 +130,52 @@ def test_serialize(self, trigger):
"trigger_kwargs": {},
}

def test_serialize_with_connection_extras(self):
extras = {"token": "abc"}
trigger = KubernetesPodTrigger(
pod_name=POD_NAME,
pod_namespace=NAMESPACE,
base_container_name=BASE_CONTAINER_NAME,
kubernetes_conn_id=CONN_ID,
connection_extras=extras,
poll_interval=POLL_INTERVAL,
cluster_context=CLUSTER_CONTEXT,
config_dict=CONFIG_DICT,
in_cluster=IN_CLUSTER,
get_logs=GET_LOGS,
startup_timeout=STARTUP_TIMEOUT_SECS,
startup_check_interval=STARTUP_CHECK_INTERVAL_SECS,
schedule_timeout=STARTUP_TIMEOUT_SECS,
trigger_start_time=TRIGGER_START_TIME,
on_finish_action=ON_FINISH_ACTION,
)

_, kwargs_dict = trigger.serialize()

assert kwargs_dict["connection_extras"] == extras

def test_hook_uses_provided_connection_extras(self):
extras = {"token": "abc"}
trigger = KubernetesPodTrigger(
pod_name=POD_NAME,
pod_namespace=NAMESPACE,
base_container_name=BASE_CONTAINER_NAME,
kubernetes_conn_id=CONN_ID,
connection_extras=extras,
poll_interval=POLL_INTERVAL,
cluster_context=CLUSTER_CONTEXT,
config_dict=CONFIG_DICT,
in_cluster=IN_CLUSTER,
get_logs=GET_LOGS,
startup_timeout=STARTUP_TIMEOUT_SECS,
startup_check_interval=STARTUP_CHECK_INTERVAL_SECS,
schedule_timeout=STARTUP_TIMEOUT_SECS,
trigger_start_time=TRIGGER_START_TIME,
on_finish_action=ON_FINISH_ACTION,
)

assert trigger.hook._extras == extras

@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_PATH}._wait_for_pod_start")
async def test_run_loop_return_success_event(self, mock_wait_pod, trigger):
Expand Down