@@ -24,6 +24,7 @@
 if TYPE_CHECKING:
     from airflow.providers.cncf.kubernetes.executors.kubernetes_executor_types import FailureDetails
 
+from airflow.providers.cncf.kubernetes.executors.kubernetes_executor_types import KubernetesResults
 from kubernetes_tests.test_base import (
     EXECUTOR,
     BaseK8STest,  # isort:skip (needed to workaround isort bug)
@@ -138,15 +139,19 @@ def test_pod_failure_logging_with_container_terminated(self, mock_log):
         # Create a test task key
         task_key = TaskInstanceKey(dag_id="test_dag", task_id="test_task", run_id="test_run", try_number=1)
 
-        # Call _change_state with FAILED status and failure details
-        executor._change_state(
+        # Create KubernetesResults object
+        results = KubernetesResults(
             key=task_key,
             state=TaskInstanceState.FAILED,
             pod_name="test-pod",
             namespace="test-namespace",
+            resource_version="123",
             failure_details=failure_details,
         )
 
+        # Call _change_state with KubernetesResults object
+        executor._change_state(results)
+
         # Verify that the warning log was called with expected parameters
         mock_log.warning.assert_called_once_with(
             "Task %s failed in pod %s/%s. Pod phase: %s, reason: %s, message: %s, "
@@ -181,15 +186,19 @@ def test_pod_failure_logging_exception_handling(self, mock_log):
         # Create a test task key
         task_key = TaskInstanceKey(dag_id="test_dag", task_id="test_task", run_id="test_run", try_number=1)
 
-        # Call _change_state with FAILED status but no failure details
-        executor._change_state(
+        # Create KubernetesResults object without failure details
+        results = KubernetesResults(
             key=task_key,
             state=TaskInstanceState.FAILED,
             pod_name="test-pod",
             namespace="test-namespace",
+            resource_version="123",
             failure_details=None,
         )
 
+        # Call _change_state with KubernetesResults object
+        executor._change_state(results)
+
         # Verify that the warning log was called with the correct parameters
         mock_log.warning.assert_called_once_with(
             "Task %s failed in pod %s/%s (no details available)",
@@ -214,11 +223,19 @@ def test_pod_failure_logging_non_failed_state(self, mock_log):
         # Create a test task key
         task_key = TaskInstanceKey(dag_id="test_dag", task_id="test_task", run_id="test_run", try_number=1)
 
-        # Call _change_state with SUCCESS status
-        executor._change_state(
-            key=task_key, state=TaskInstanceState.SUCCESS, pod_name="test-pod", namespace="test-namespace"
+        # Create KubernetesResults object with SUCCESS state
+        results = KubernetesResults(
+            key=task_key,
+            state=TaskInstanceState.SUCCESS,
+            pod_name="test-pod",
+            namespace="test-namespace",
+            resource_version="123",
+            failure_details=None,
         )
 
+        # Call _change_state with KubernetesResults object
+        executor._change_state(results)
+
         # Verify that no failure logs were called
         mock_log.error.assert_not_called()
         mock_log.warning.assert_not_called()
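Note on the test changes above: the tests now build a `KubernetesResults` value instead of passing loose keyword arguments to `_change_state`. A minimal sketch, not part of this PR and with simplified `str` fields standing in for Airflow's `TaskInstanceKey` and `TaskInstanceState`, of why the NamedTuple is a drop-in replacement for the old result tuple:

```python
# Minimal sketch: simplified field types to show tuple compatibility.
from __future__ import annotations

from typing import NamedTuple


class KubernetesResults(NamedTuple):
    key: str
    state: str | None
    pod_name: str
    namespace: str
    resource_version: str
    failure_details: dict | None


results = KubernetesResults("dag.task", "failed", "test-pod", "test-namespace", "123", None)

# Old-style positional unpacking still works on a NamedTuple...
key, state, pod_name, namespace, resource_version, failure_details = results
# ...while new code reads named fields instead of remembering positions.
assert results.namespace == namespace
assert results.resource_version == "123"
```

Because a NamedTuple is still a tuple, any consumer that unpacks six positional values keeps working during the migration; only the construction sites and field accesses need to change.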
@@ -66,7 +66,8 @@
 from airflow.providers.cncf.kubernetes.executors.kubernetes_executor_types import (
     ADOPTED,
     POD_EXECUTOR_DONE_KEY,
-    FailureDetails,
+    KubernetesJob,
+    KubernetesResults,
 )
 from airflow.providers.cncf.kubernetes.kube_config import KubeConfig
 from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import annotations_to_key
@@ -86,10 +87,6 @@
     from airflow.executors import workloads
     from airflow.models.taskinstance import TaskInstance
     from airflow.models.taskinstancekey import TaskInstanceKey
-    from airflow.providers.cncf.kubernetes.executors.kubernetes_executor_types import (
-        KubernetesJobType,
-        KubernetesResultsType,
-    )
     from airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils import (
         AirflowKubernetesScheduler,
     )
@@ -157,8 +154,8 @@ class KubernetesExecutor(BaseExecutor):
     def __init__(self):
         self.kube_config = KubeConfig()
         self._manager = multiprocessing.Manager()
-        self.task_queue: Queue[KubernetesJobType] = self._manager.Queue()
-        self.result_queue: Queue[KubernetesResultsType] = self._manager.Queue()
+        self.task_queue: Queue[KubernetesJob] = self._manager.Queue()
+        self.result_queue: Queue[KubernetesResults] = self._manager.Queue()
         self.kube_scheduler: AirflowKubernetesScheduler | None = None
         self.kube_client: client.CoreV1Api | None = None
         self.scheduler_job_id: str | None = None
@@ -280,7 +277,7 @@ def execute_async(
         else:
             pod_template_file = None
         self.event_buffer[key] = (TaskInstanceState.QUEUED, self.scheduler_job_id)
-        self.task_queue.put((key, command, kube_executor_config, pod_template_file))
+        self.task_queue.put(KubernetesJob(key, command, kube_executor_config, pod_template_file))
         # We keep a temporary local record that we've handled this so we don't
         # try and remove it from the QUEUED state while we process it
         self.last_handled[key] = time.time()
@@ -331,17 +328,16 @@ def sync(self) -> None:
             while True:
                 results = self.result_queue.get_nowait()
                 try:
-                    key, state, pod_name, namespace, resource_version, failure_details = results
-                    last_resource_version[namespace] = resource_version
-                    self.log.info("Changing state of %s to %s", results, state)
+                    last_resource_version[results.namespace] = results.resource_version
+                    self.log.info("Changing state of %s to %s", results, results.state)
                     try:
-                        self._change_state(key, state, pod_name, namespace, failure_details)
+                        self._change_state(results)
                     except Exception as e:
                         self.log.exception(
                             "Exception: %s when attempting to change state of %s to %s, re-queueing.",
                             e,
                             results,
-                            state,
+                            results.state,
                         )
                         self.result_queue.put(results)
                 finally:
@@ -362,7 +358,7 @@ def sync(self) -> None:
                 task = self.task_queue.get_nowait()
 
                 try:
-                    key, command, kube_executor_config, pod_template_file = task
+                    key = task.key
                     self.kube_scheduler.run_next(task)
                     self.task_publish_retries.pop(key, None)
                 except PodReconciliationError as e:
@@ -391,11 +387,11 @@ def sync(self) -> None:
                         self.task_publish_retries[key] = retries + 1
                     else:
                         self.log.error("Pod creation failed with reason %r. Failing task", e.reason)
-                        key, _, _, _ = task
+                        key = task.key
                         self.fail(key, e)
                         self.task_publish_retries.pop(key, None)
                 except PodMutationHookException as e:
-                    key, _, _, _ = task
+                    key = task.key
                     self.log.error(
                         "Pod Mutation Hook failed for the task %s. Failing task. Details: %s",
                         key,
@@ -408,16 +404,19 @@ def sync(self) -> None:
     @provide_session
     def _change_state(
         self,
-        key: TaskInstanceKey,
-        state: TaskInstanceState | str | None,
-        pod_name: str,
-        namespace: str,
-        failure_details: FailureDetails | None = None,
+        results: KubernetesResults,
         session: Session = NEW_SESSION,
     ) -> None:
+        """Change state of the task based on KubernetesResults."""
         if TYPE_CHECKING:
             assert self.kube_scheduler
 
+        key = results.key
+        state = results.state
+        pod_name = results.pod_name
+        namespace = results.namespace
+        failure_details = results.failure_details
+
         if state == TaskInstanceState.FAILED:
             # Use pre-collected failure details from the watcher to avoid additional API calls
             if failure_details:
@@ -734,18 +733,20 @@ def _flush_result_queue(self) -> None:
                 results = self.result_queue.get_nowait()
                 self.log.warning("Executor shutting down, flushing results=%s", results)
                 try:
-                    key, state, pod_name, namespace, resource_version, failure_details = results
                     self.log.info(
-                        "Changing state of %s to %s : resource_version=%d", results, state, resource_version
+                        "Changing state of %s to %s : resource_version=%s",
+                        results,
+                        results.state,
+                        results.resource_version,
                     )
                     try:
-                        self._change_state(key, state, pod_name, namespace, failure_details)
+                        self._change_state(results)
                     except Exception as e:
                         self.log.exception(
                             "Ignoring exception: %s when attempting to change state of %s to %s.",
                             e,
                             results,
-                            state,
+                            results.state,
                         )
                 finally:
                     self.result_queue.task_done()
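The executor changes above leave the drain-and-requeue control flow of `sync()` and `_flush_result_queue()` intact and only swap positional unpacking for attribute access. A standalone sketch of that loop, assuming a plain `queue.Queue` and a stand-in handler in place of the real `_change_state`:

```python
# Standalone sketch of the drain pattern used in sync() above;
# change_state is a stand-in for KubernetesExecutor._change_state.
import queue

result_queue: queue.Queue = queue.Queue()
result_queue.put(("task-1", "failed"))


def change_state(results) -> None:
    print("changing state of", results)


while True:
    try:
        results = result_queue.get_nowait()
    except queue.Empty:
        break  # queue drained
    try:
        try:
            change_state(results)
        except Exception:
            # The real executor logs the exception and re-queues the result.
            result_queue.put(results)
    finally:
        # Mark the item done whether or not the state change succeeded.
        result_queue.task_done()
```

The real code suppresses `Empty` with `contextlib.suppress` rather than an explicit `except`, but the shape is the same: one `task_done()` per `get_nowait()`, with failed results pushed back for the next pass.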
@@ -16,7 +16,14 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Literal, TypedDict
+from typing import TYPE_CHECKING, Any, Literal, NamedTuple, TypedDict
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+    from airflow.models.taskinstance import TaskInstanceKey
+    from airflow.utils.state import TaskInstanceState
+
 
 ADOPTED = "adopted"
 
@@ -35,27 +42,40 @@ class FailureDetails(TypedDict, total=False):
     container_name: str | None
 
 
-if TYPE_CHECKING:
-    from collections.abc import Sequence
-
-    from airflow.models.taskinstance import TaskInstanceKey
-    from airflow.utils.state import TaskInstanceState
-
-    # TODO: Remove after Airflow 2 support is removed
-    CommandType = Sequence[str]
-
-    # TaskInstance key, command, configuration, pod_template_file
-    KubernetesJobType = tuple[TaskInstanceKey, CommandType, Any, str | None]
-
-    # key, pod state, pod_name, namespace, resource_version, failure_details
-    KubernetesResultsType = tuple[
-        TaskInstanceKey, TaskInstanceState | str | None, str, str, str, FailureDetails | None
-    ]
-
-    # pod_name, namespace, pod state, annotations, resource_version, failure_details
-    KubernetesWatchType = tuple[
-        str, str, TaskInstanceState | str | None, dict[str, str], str, FailureDetails | None
-    ]
+class KubernetesResults(NamedTuple):
+    """Results from Kubernetes task execution."""
+
+    key: TaskInstanceKey
+    state: TaskInstanceState | str | None
+    pod_name: str
+    namespace: str
+    resource_version: str
+    failure_details: FailureDetails | None
+
+
+class KubernetesWatch(NamedTuple):
+    """Watch event data from Kubernetes pods."""
+
+    pod_name: str
+    namespace: str
+    state: TaskInstanceState | str | None
+    annotations: dict[str, str]
+    resource_version: str
+    failure_details: FailureDetails | None
+
+
+# TODO: Remove after Airflow 2 support is removed
+CommandType = "Sequence[str]"
+
+
+class KubernetesJob(NamedTuple):
+    """Job definition for Kubernetes execution."""
+
+    key: TaskInstanceKey
+    command: Sequence[str]
+    kube_executor_config: Any
+    pod_template_file: str | None
 
 
 ALL_NAMESPACES = "ALL_NAMESPACES"
 POD_EXECUTOR_DONE_KEY = "airflow_executor_done"
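One subtlety in the types module above: the legacy `CommandType` alias becomes the string `"Sequence[str]"` because `Sequence` is imported only under `TYPE_CHECKING` and does not exist at runtime, while `from __future__ import annotations` keeps the NamedTuple field annotations lazy. A small sketch of why this split is safe, with `Any` standing in for `TaskInstanceKey`:

```python
# Sketch of the runtime/typing split used in kubernetes_executor_types.py.
from __future__ import annotations

from typing import TYPE_CHECKING, Any, NamedTuple

if TYPE_CHECKING:
    from collections.abc import Sequence  # visible to type checkers only

# Quoted so the module imports cleanly at runtime, where Sequence is undefined.
CommandType = "Sequence[str]"


class KubernetesJob(NamedTuple):
    key: Any
    command: Sequence[str]  # lazy string annotation, never evaluated at runtime
    kube_executor_config: Any
    pod_template_file: str | None


job = KubernetesJob("dag.task", ["airflow", "tasks", "run"], None, None)
key = job.key  # named access replaces the old `key, _, _, _ = task` unpacking
```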