Merged
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import asyncio
import json
import logging
import os
@@ -183,6 +184,7 @@ def test_do_xcom_push_defaults_false(self, kubeconfig_path, mock_get_connection,
)
assert not k.do_xcom_push

@pytest.mark.asyncio
Member:

I don't think you should need this -- it's likely only needed to make asyncio.get_event_loop() pass, but a worker doesn't have an event loop, and if we swap it to asyncio.run() it won't need this marker. I.e. adding this mark was working around the test failing in a way that is not representative of how the KPO actually runs?
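(For illustration only -- a generic sketch, not the operator or test code: asyncio.run() manages its own loop, so the caller doesn't need a current event loop to exist, whereas get_event_loop()-style code does, which is the dependency the marker was satisfying in these tests.)

import asyncio

async def work() -> str:
    await asyncio.sleep(0)
    return "done"

# Relies on a current event loop being available (and emits a DeprecationWarning
# on Python 3.12+ when none is set) -- the dependency discussed above:
print(asyncio.get_event_loop().run_until_complete(work()))

# Creates, runs and closes its own loop -- no marker or pre-existing loop needed:
print(asyncio.run(work()))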

def test_config_path_move(self, kubeconfig_path, mock_get_connection, tmp_path):
new_config_path = tmp_path / "kube_config.cfg"
shutil.copy(kubeconfig_path, new_config_path)
@@ -205,6 +207,7 @@ def test_config_path_move(self, kubeconfig_path, mock_get_connection, tmp_path):
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
assert actual_pod == expected_pod

@pytest.mark.asyncio
def test_working_pod(self, mock_get_connection):
k = KubernetesPodOperator(
namespace="default",
@@ -222,6 +225,7 @@ def test_working_pod(self, mock_get_connection):
assert self.expected_pod["spec"] == actual_pod["spec"]
assert self.expected_pod["metadata"]["labels"] == actual_pod["metadata"]["labels"]

@pytest.mark.asyncio
def test_skip_cleanup(self, mock_get_connection):
k = KubernetesPodOperator(
namespace="unknown",
@@ -237,6 +241,7 @@ def test_skip_cleanup(self, mock_get_connection):
with pytest.raises(ApiException):
k.execute(context)

@pytest.mark.asyncio
def test_delete_operator_pod(self, mock_get_connection):
k = KubernetesPodOperator(
namespace="default",
@@ -255,6 +260,7 @@ def test_delete_operator_pod(self, mock_get_connection):
assert self.expected_pod["spec"] == actual_pod["spec"]
assert self.expected_pod["metadata"]["labels"] == actual_pod["metadata"]["labels"]

@pytest.mark.asyncio
def test_skip_on_specified_exit_code(self, mock_get_connection):
k = KubernetesPodOperator(
namespace="default",
@@ -271,6 +277,7 @@ def test_skip_on_specified_exit_code(self, mock_get_connection):
with pytest.raises(AirflowSkipException):
k.execute(context)

@pytest.mark.asyncio
def test_already_checked_on_success(self, mock_get_connection):
"""
When ``on_finish_action="keep_pod"``, pod should have 'already_checked'
@@ -293,6 +300,7 @@ def test_already_checked_on_success(self, mock_get_connection):
actual_pod = self.api_client.sanitize_for_serialization(actual_pod)
assert actual_pod["metadata"]["labels"]["already_checked"] == "True"

@pytest.mark.asyncio
def test_already_checked_on_failure(self, mock_get_connection):
"""
When ``on_finish_action="keep_pod"``, pod should have 'already_checked'
@@ -318,6 +326,7 @@ def test_already_checked_on_failure(self, mock_get_connection):
assert status["state"]["terminated"]["reason"] == "Error"
assert actual_pod["metadata"]["labels"]["already_checked"] == "True"

@pytest.mark.asyncio
def test_pod_hostnetwork(self, mock_get_connection):
k = KubernetesPodOperator(
namespace="default",
@@ -337,6 +346,7 @@ def test_pod_hostnetwork(self, mock_get_connection):
assert self.expected_pod["spec"] == actual_pod["spec"]
assert self.expected_pod["metadata"]["labels"] == actual_pod["metadata"]["labels"]

@pytest.mark.asyncio
def test_pod_dnspolicy(self, mock_get_connection):
dns_policy = "ClusterFirstWithHostNet"
k = KubernetesPodOperator(
@@ -359,6 +369,7 @@ def test_pod_dnspolicy(self, mock_get_connection):
assert self.expected_pod["spec"] == actual_pod["spec"]
assert self.expected_pod["metadata"]["labels"] == actual_pod["metadata"]["labels"]

@pytest.mark.asyncio
def test_pod_schedulername(self, mock_get_connection):
scheduler_name = "default-scheduler"
k = KubernetesPodOperator(
@@ -378,6 +389,7 @@ def test_pod_schedulername(self, mock_get_connection):
self.expected_pod["spec"]["schedulerName"] = scheduler_name
assert self.expected_pod == actual_pod

@pytest.mark.asyncio
def test_pod_node_selector(self, mock_get_connection):
node_selector = {"beta.kubernetes.io/os": "linux"}
k = KubernetesPodOperator(
@@ -397,6 +409,7 @@ def test_pod_node_selector(self, mock_get_connection):
self.expected_pod["spec"]["nodeSelector"] = node_selector
assert self.expected_pod == actual_pod

@pytest.mark.asyncio
def test_pod_resources(self, mock_get_connection):
resources = k8s.V1ResourceRequirements(
requests={"memory": "64Mi", "cpu": "250m", "ephemeral-storage": "1Gi"},
@@ -422,6 +435,7 @@ def test_pod_resources(self, mock_get_connection):
}
assert self.expected_pod == actual_pod

@pytest.mark.asyncio
@pytest.mark.parametrize(
"val",
[
@@ -498,6 +512,7 @@ def test_pod_affinity(self, val, mock_get_connection):
self.expected_pod["spec"]["affinity"] = expected
assert self.expected_pod == actual_pod

@pytest.mark.asyncio
def test_port(self, mock_get_connection):
port = k8s.V1ContainerPort(
name="http",
@@ -521,6 +536,7 @@ def test_port(self, mock_get_connection):
self.expected_pod["spec"]["containers"][0]["ports"] = [{"name": "http", "containerPort": 80}]
assert self.expected_pod == actual_pod

@pytest.mark.asyncio
def test_volume_mount(self, mock_get_connection):
with mock.patch.object(PodManager, "log") as mock_logger:
volume_mount = k8s.V1VolumeMount(
@@ -560,6 +576,7 @@ def test_volume_mount(self, mock_get_connection):
]
assert self.expected_pod == actual_pod

@pytest.mark.asyncio
@pytest.mark.parametrize("uid", [0, 1000])
def test_run_as_user(self, uid, mock_get_connection):
security_context = V1PodSecurityContext(run_as_user=uid)
@@ -585,6 +602,7 @@ def test_run_as_user(self, uid, mock_get_connection):
)
assert pod.to_dict()["spec"]["security_context"]["run_as_user"] == uid

@pytest.mark.asyncio
@pytest.mark.parametrize("gid", [0, 1000])
def test_fs_group(self, gid, mock_get_connection):
security_context = V1PodSecurityContext(fs_group=gid)
@@ -610,6 +628,7 @@ def test_fs_group(self, gid, mock_get_connection):
)
assert pod.to_dict()["spec"]["security_context"]["fs_group"] == gid

@pytest.mark.asyncio
def test_disable_privilege_escalation(self, mock_get_connection):
container_security_context = V1SecurityContext(allow_privilege_escalation=False)

@@ -632,6 +651,7 @@ def test_disable_privilege_escalation(self, mock_get_connection):
}
assert self.expected_pod == actual_pod

@pytest.mark.asyncio
def test_faulty_image(self, mock_get_connection):
bad_image_name = "foobar"
k = KubernetesPodOperator(
@@ -670,6 +690,7 @@ def test_faulty_service_account(self, mock_get_connection):
with pytest.raises(ApiException, match="error looking up service account default/foobar"):
k.get_or_create_pod(pod, context)

@pytest.mark.asyncio
def test_pod_failure(self, mock_get_connection):
"""
Tests that the task fails when a pod reports a failure
@@ -692,6 +713,7 @@ def test_pod_failure(self, mock_get_connection):
self.expected_pod["spec"]["containers"][0]["args"] = bad_internal_command
assert self.expected_pod == actual_pod

@pytest.mark.asyncio
def test_xcom_push(self, test_label, mock_get_connection):
expected = {"test_label": test_label, "buzz": 2}
args = [f"echo '{json.dumps(expected)}' > /airflow/xcom/return.json"]
@@ -740,6 +762,7 @@ def test_env_vars(self, mock_get_connection):
]
assert self.expected_pod == actual_pod

@pytest.mark.asyncio
def test_pod_template_file_system(self, mock_get_connection, basic_pod_template):
"""Note: this test requires that you have a namespace ``mem-example`` in your cluster."""
k = KubernetesPodOperator(
@@ -755,6 +778,7 @@ def test_pod_template_file_system(self, mock_get_connection, basic_pod_template)
assert result is not None
assert result == {"hello": "world"}

@pytest.mark.asyncio
@pytest.mark.parametrize(
"env_vars",
[
@@ -790,6 +814,7 @@ def test_pod_template_file_with_overrides_system(
assert k.pod.spec.containers[0].env == [k8s.V1EnvVar(name="env_name", value="value")]
assert result == {"hello": "world"}

@pytest.mark.asyncio
def test_pod_template_file_with_full_pod_spec(self, test_label, mock_get_connection, basic_pod_template):
pod_spec = k8s.V1Pod(
metadata=k8s.V1ObjectMeta(
@@ -830,6 +855,7 @@ def test_pod_template_file_with_full_pod_spec(self, test_label, mock_get_connect
assert k.pod.spec.containers[0].env == [k8s.V1EnvVar(name="env_name", value="value")]
assert result == {"hello": "world"}

@pytest.mark.asyncio
def test_full_pod_spec(self, test_label, mock_get_connection):
pod_spec = k8s.V1Pod(
metadata=k8s.V1ObjectMeta(
@@ -875,6 +901,7 @@ def test_full_pod_spec(self, test_label, mock_get_connection):
assert k.pod.spec.containers[0].env == [k8s.V1EnvVar(name="env_name", value="value")]
assert result == {"hello": "world"}

@pytest.mark.asyncio
def test_init_container(self, mock_get_connection):
# GIVEN
volume_mounts = [
@@ -929,6 +956,7 @@ def test_init_container(self, mock_get_connection):
]
assert self.expected_pod == actual_pod

@pytest.mark.asyncio
@mock.patch(f"{POD_MANAGER_CLASS}.await_xcom_sidecar_container_start")
@mock.patch(f"{POD_MANAGER_CLASS}.extract_xcom")
@mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion")
@@ -1036,6 +1064,7 @@ def test_pod_template_file(
del actual_pod["metadata"]["labels"]["airflow_version"]
assert expected_dict == actual_pod

@pytest.mark.asyncio
@mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion")
@mock.patch(f"{POD_MANAGER_CLASS}.create_pod", new=MagicMock)
@mock.patch(HOOK_CLASS)
@@ -1070,6 +1099,7 @@ def test_pod_priority_class_name(self, hook_mock, await_pod_completion_mock):
self.expected_pod["spec"]["priorityClassName"] = priority_class_name
assert self.expected_pod == actual_pod

@pytest.mark.asyncio
def test_pod_name(self, mock_get_connection):
pod_name_too_long = "a" * 221
k = KubernetesPodOperator(
@@ -1089,6 +1119,7 @@ def test_pod_name(self, mock_get_connection):
with pytest.raises(AirflowException):
k.execute(context)

@pytest.mark.asyncio
def test_on_kill(self, mock_get_connection):
hook = KubernetesHook(conn_id=None, in_cluster=False)
client = hook.core_v1_client
@@ -1129,6 +1160,7 @@ class ShortCircuitException(Exception):
with pytest.raises(ApiException, match=r'pods \\"test.[a-z0-9]+\\" not found'):
client.read_namespaced_pod(name=name, namespace=namespace)

@pytest.mark.asyncio
def test_reattach_failing_pod_once(self, mock_get_connection):
hook = KubernetesHook(conn_id=None, in_cluster=False)
client = hook.core_v1_client
@@ -1189,13 +1221,19 @@ def get_op():
# recreate op just to ensure we're not relying on any statefulness
k = get_op()

# Before next attempt we need to re-create event loop if it is closed.
loop = asyncio.get_event_loop()
if loop.is_closed():
asyncio.set_event_loop(asyncio.new_event_loop())

# `create_pod` should be called because though there's still a pod to be found,
# it will be `already_checked`
with mock.patch(f"{POD_MANAGER_CLASS}.create_pod") as create_mock:
with pytest.raises(ApiException, match=r'pods \\"test.[a-z0-9]+\\" not found'):
k.execute(context)
create_mock.assert_called_once()

@pytest.mark.asyncio
def test_changing_base_container_name_with_get_logs(self, mock_get_connection):
k = KubernetesPodOperator(
namespace="default",
@@ -1221,6 +1259,7 @@ def test_changing_base_container_name_with_get_logs(self, mock_get_connection):
self.expected_pod["spec"]["containers"][0]["name"] = "apple-sauce"
assert self.expected_pod["spec"] == actual_pod["spec"]

@pytest.mark.asyncio
def test_changing_base_container_name_no_logs(self, mock_get_connection):
"""
This test checks BOTH a modified base container name AND the get_logs=False flow,
@@ -1251,6 +1290,7 @@ def test_changing_base_container_name_no_logs(self, mock_get_connection):
self.expected_pod["spec"]["containers"][0]["name"] = "apple-sauce"
assert self.expected_pod["spec"] == actual_pod["spec"]

@pytest.mark.asyncio
def test_changing_base_container_name_no_logs_long(self, mock_get_connection):
"""
Similar to test_changing_base_container_name_no_logs, but ensures that
@@ -1282,6 +1322,7 @@ def test_changing_base_container_name_no_logs_long(self, mock_get_connection):
self.expected_pod["spec"]["containers"][0]["args"] = ["sleep 3"]
assert self.expected_pod["spec"] == actual_pod["spec"]

@pytest.mark.asyncio
def test_changing_base_container_name_failure(self, mock_get_connection):
k = KubernetesPodOperator(
namespace="default",
@@ -1328,6 +1369,7 @@ class MyK8SPodOperator(KubernetesPodOperator):
)
assert MyK8SPodOperator(task_id=str(uuid4())).base_container_name == "tomato-sauce"

@pytest.mark.asyncio
def test_init_container_logs(self, mock_get_connection):
marker_from_init_container = f"{uuid4()}"
marker_from_main_container = f"{uuid4()}"
@@ -1359,6 +1401,7 @@ def test_init_container_logs(self, mock_get_connection):
assert marker_from_init_container in calls_args
assert marker_from_main_container in calls_args

@pytest.mark.asyncio
def test_init_container_logs_filtered(self, mock_get_connection):
marker_from_init_container_to_log_1 = f"{uuid4()}"
marker_from_init_container_to_log_2 = f"{uuid4()}"
@@ -1456,6 +1499,7 @@ def __getattr__(self, name):


class TestKubernetesPodOperator(BaseK8STest):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"active_deadline_seconds,should_fail",
[(3, True), (60, False)],
@@ -18,6 +18,7 @@

from __future__ import annotations

import asyncio
import datetime
import json
import logging
@@ -602,12 +603,20 @@ def get_or_create_pod(self, pod_request_obj: k8s.V1Pod, context: Context) -> k8s

def await_pod_start(self, pod: k8s.V1Pod) -> None:
try:
self.pod_manager.await_pod_start(
pod=pod,
schedule_timeout=self.schedule_timeout_seconds,
startup_timeout=self.startup_timeout_seconds,
check_interval=self.startup_check_interval_seconds,
loop = asyncio.get_event_loop()
Member:

Specifically this: this is running in a normal sync worker, where there is no running event loop, and this raises an exception.

Member:

Hmmm, per the docs:

If there is no running event loop set, the function will return the result of the get_event_loop_policy().get_event_loop() call.

Member (@ashb, Jul 16, 2025):

Oh, I wonder if it's related to this, and some user code turning that into an Error.

Deprecated since version 3.12: Deprecation warning is emitted if there is no current event loop. In some future Python release this will become an error.

Trying to confirm.
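A minimal sketch of the suspected failure mode (assuming Python 3.12+ and something in the environment promoting DeprecationWarning to an error -- neither is confirmed by the report):

import asyncio
import warnings

# Assumption: user or framework code somewhere escalates DeprecationWarning to an error.
warnings.simplefilter("error", DeprecationWarning)

# On Python 3.12+ this emits a DeprecationWarning when no current event loop is set;
# with the filter above it raises instead, which would make await_pod_start() fail
# on a plain sync worker before the pod is ever watched.
loop = asyncio.get_event_loop()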

Contributor:

I am not sure, then, what the right way is to run things in parallel from a sync task. For us it has been working like this in production for a year now. But maybe our environment is not representative?

The challenge is to parse and follow logs and events in parallel. There is no K8s API that delivers both concurrently, and flipping back and forth is very inefficient if you want to listen to the log stream. Therefore we took the async approach.

Do you have more details on where and how it is "breaking"? Which environment?

Note that we also initially attempted to run another thread instead of using asyncio, but that was blocked by Celery as well and is probably not advised either.
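Roughly the shape of that async approach, as a self-contained sketch -- fetch_log_lines, fetch_events and pod_is_done are hypothetical stand-ins for the real Kubernetes calls, not provider APIs:

import asyncio

async def follow_logs(fetch_log_lines, pod_is_done, interval: float = 1.0) -> None:
    # Tail the (hypothetical) log source until the pod reports completion.
    while not pod_is_done():
        for line in fetch_log_lines():
            print("log:", line)
        await asyncio.sleep(interval)

async def watch_events(fetch_events, pod_is_done, interval: float = 2.0) -> None:
    # Watch (hypothetical) pod events concurrently with the log stream.
    while not pod_is_done():
        for event in fetch_events():
            print("event:", event)
        await asyncio.sleep(interval)

async def follow_pod(fetch_log_lines, fetch_events, pod_is_done) -> None:
    # One event loop drives both watchers instead of flipping between the two K8s APIs.
    await asyncio.gather(
        follow_logs(fetch_log_lines, pod_is_done),
        watch_events(fetch_events, pod_is_done),
    )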

Member:

I don't have many details -- it's second-hand through one of the Astronomer customers.

I think this is something turning the deprecation warning into an error -- I think the fix/workaround is to swap the manual loop control to asyncio.run, as per this in the docs:

Application developers should typically use the high-level asyncio functions, such as asyncio.run(), and should rarely need to reference the loop object or call its methods. This section is intended mostly for authors of lower-level code, libraries, and frameworks, who need finer control over the event loop behavior.
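A sketch of what that swap could look like here, reusing the coroutines from the hunk below (the _start_and_watch_pod helper name is made up; this is the suggested shape, not the merged change):

async def _start_and_watch_pod(self, pod: k8s.V1Pod) -> None:
    # Watch pod events concurrently with the startup wait, on the loop owned by asyncio.run().
    events_task = asyncio.create_task(
        self.pod_manager.watch_pod_events(pod, self.startup_check_interval_seconds)
    )
    await self.pod_manager.await_pod_start(
        pod=pod,
        schedule_timeout=self.schedule_timeout_seconds,
        startup_timeout=self.startup_timeout_seconds,
        check_interval=self.startup_check_interval_seconds,
    )
    await events_task

def await_pod_start(self, pod: k8s.V1Pod) -> None:
    try:
        # No manual loop management: asyncio.run() creates, uses and closes its own loop.
        asyncio.run(self._start_and_watch_pod(pod))
    except PodLaunchFailedException:
        # existing error handling stays unchanged
        ...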

Contributor:

Can we get a full report, including a way to reproduce, as a GH issue on this? It would also be great to include a test then to prevent a regression.

Member:

Trying to see what I can get. All I have right now is the stacktrace I put in a comment above.

events_task = asyncio.ensure_future(
self.pod_manager.watch_pod_events(pod, self.startup_check_interval_seconds)
)
loop.run_until_complete(
self.pod_manager.await_pod_start(
pod=pod,
schedule_timeout=self.schedule_timeout_seconds,
startup_timeout=self.startup_timeout_seconds,
check_interval=self.startup_check_interval_seconds,
)
)
loop.run_until_complete(events_task)
loop.close()
except PodLaunchFailedException:
if self.log_events_on_failure:
self._read_pod_events(pod, reraise=False)