diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
index a01f05ae82de6..bb75f7975252c 100644
--- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
+++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
@@ -248,23 +248,32 @@ def find_spark_job(self, context, exclude_checked: bool = True):
             self._build_find_pod_label_selector(context, exclude_checked=exclude_checked)
             + ",spark-role=driver"
         )
-        pod_list = self.client.list_namespaced_pod(self.namespace, label_selector=label_selector).items
+        # No resource version is specified, so the API server returns the latest state.
+        # The field selector drops pods in terminal failure states server-side.
+        field_selector = self._get_field_selector()
+        pod_list = self.client.list_namespaced_pod(
+            self.namespace, label_selector=label_selector, field_selector=field_selector
+        ).items

         pod = None
         if len(pod_list) > 1:
-            # When multiple pods match the same labels, select one deterministically,
-            # preferring a Running pod, then creation time, with name as a tie-breaker.
+            # When multiple pods match the same labels, select one deterministically.
+            # Prefer non-terminating pods, then Succeeded, then Pending, then Running.
+            # Terminating pods are identified via deletion_timestamp; Pending ranks above
+            # Running so a recently restarted driver is reattached without failing the task.
             pod = max(
                 pod_list,
                 key=lambda p: (
-                    p.status.phase == PodPhase.RUNNING,
+                    p.metadata.deletion_timestamp is None,  # deprioritize terminating pods
+                    p.status.phase == PodPhase.SUCCEEDED,  # the job may have succeeded while the worker was down
+                    p.status.phase == PodPhase.PENDING,
                     p.metadata.creation_timestamp or datetime.min.replace(tzinfo=timezone.utc),
                     p.metadata.name or "",
                 ),
             )
             self.log.warning(
                 "Found %d Spark driver pods matching labels %s; "
-                "selecting pod %s for reattachment based on status and creation time.",
+                "selecting pod %s for reattachment based on status.",
                 len(pod_list),
                 label_selector,
                 pod.metadata.name,
@@ -279,6 +288,10 @@ def find_spark_job(self, context, exclude_checked: bool = True):
             self.log.info("`try_number` of pod: %s", pod.metadata.labels.get("try_number", "unknown"))
         return pod

+    def _get_field_selector(self) -> str:
+        # Exclude terminal failure states so only Running, Pending, and Succeeded pods are returned.
+ return f"status.phase!={PodPhase.FAILED},status.phase!={PodPhase.UNKNOWN}" + def process_pod_deletion(self, pod, *, reraise=True): if pod is not None: if self.delete_on_termination: diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py index cd25f06a76eaf..01f8f4a3fa2c6 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py @@ -90,6 +90,7 @@ class PodPhase: RUNNING = "Running" FAILED = "Failed" SUCCEEDED = "Succeeded" + UNKNOWN = "Unknown" terminal_states = {FAILED, SUCCEEDED} diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py index 1e7f5c1da23d4..31c3881138504 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py @@ -863,7 +863,11 @@ def test_find_custom_pod_labels( op.execute(context) label_selector = op._build_find_pod_label_selector(context) + ",spark-role=driver" op.find_spark_job(context) - mock_get_kube_client.list_namespaced_pod.assert_called_with("default", label_selector=label_selector) + mock_get_kube_client.list_namespaced_pod.assert_called_with( + "default", + label_selector=label_selector, + field_selector=op._get_field_selector(), + ) @patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook") def test_adds_task_context_labels_to_driver_and_executor( @@ -941,7 +945,11 @@ def test_reattach_on_restart_with_task_context_labels( op.execute(context) label_selector = op._build_find_pod_label_selector(context) + ",spark-role=driver" - mock_get_kube_client.list_namespaced_pod.assert_called_with("default", label_selector=label_selector) + mock_get_kube_client.list_namespaced_pod.assert_called_with( + "default", + label_selector=label_selector, + field_selector=op._get_field_selector(), + ) mock_create_namespaced_crd.assert_not_called() @@ -983,21 +991,140 @@ def test_find_spark_job_picks_running_pod( running_pod.metadata.labels = {"try_number": "1"} running_pod.status.phase = "Running" - # Pending pod should not be selected. + # Terminating pod should not be selected. + terminating_pod = mock.MagicMock() + terminating_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc) + terminating_pod.metadata.deletion_timestamp = timezone.datetime(2025, 1, 2, tzinfo=timezone.utc) + terminating_pod.metadata.name = "spark-driver-pending" + terminating_pod.metadata.labels = {"try_number": "1"} + terminating_pod.status.phase = "Running" + + mock_get_kube_client.list_namespaced_pod.return_value.items = [ + running_pod, + terminating_pod, + ] + + returned_pod = op.find_spark_job(context) + + assert returned_pod is running_pod + + def test_find_spark_job_picks_pending_pod( + self, + mock_is_in_cluster, + mock_parent_execute, + mock_create_namespaced_crd, + mock_get_namespaced_custom_object_status, + mock_cleanup, + mock_create_job_name, + mock_get_kube_client, + mock_create_pod, + mock_await_pod_completion, + mock_fetch_requested_container_logs, + data_file, + ): + """ + Verifies that find_spark_job picks a Pending Spark driver pod over a Terminating. 
+ """ + + task_name = "test_find_spark_job_prefers_pending_pod" + job_spec = yaml.safe_load(data_file("spark/application_template.yaml").read_text()) + + mock_create_job_name.return_value = task_name + op = SparkKubernetesOperator( + template_spec=job_spec, + kubernetes_conn_id="kubernetes_default_kube_config", + task_id=task_name, + get_logs=True, + reattach_on_restart=True, + ) + context = create_context(op) + + # Pending pod should be selected. pending_pod = mock.MagicMock() pending_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc) - pending_pod.metadata.name = "spark-driver-pending" + pending_pod.metadata.name = "spark-driver" pending_pod.metadata.labels = {"try_number": "1"} pending_pod.status.phase = "Pending" + # Terminating pod should not be selected. + terminating_pod = mock.MagicMock() + terminating_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc) + terminating_pod.metadata.deletion_timestamp = timezone.datetime(2025, 1, 2, tzinfo=timezone.utc) + terminating_pod.metadata.name = "spark-driver" + terminating_pod.metadata.labels = {"try_number": "1"} + terminating_pod.status.phase = "Running" + mock_get_kube_client.list_namespaced_pod.return_value.items = [ - running_pod, + terminating_pod, # comes first but should be ignored, as it is terminating pending_pod, ] returned_pod = op.find_spark_job(context) - assert returned_pod is running_pod + assert returned_pod is pending_pod + + def test_find_spark_job_picks_succeeded( + self, + mock_is_in_cluster, + mock_parent_execute, + mock_create_namespaced_crd, + mock_get_namespaced_custom_object_status, + mock_cleanup, + mock_create_job_name, + mock_get_kube_client, + mock_create_pod, + mock_await_pod_completion, + mock_fetch_requested_container_logs, + data_file, + ): + """ + Verifies that find_spark_job picks a Succeeded Spark driver pod over a non-Running pod. + """ + + task_name = "test_find_spark_job_prefers_succeeded_pod" + job_spec = yaml.safe_load(data_file("spark/application_template.yaml").read_text()) + + mock_create_job_name.return_value = task_name + op = SparkKubernetesOperator( + template_spec=job_spec, + kubernetes_conn_id="kubernetes_default_kube_config", + task_id=task_name, + get_logs=True, + reattach_on_restart=True, + ) + context = create_context(op) + + # Succeeded pod should be selected. + succeeded_pod = mock.MagicMock() + succeeded_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc) + succeeded_pod.metadata.name = "spark-driver" + succeeded_pod.metadata.labels = {"try_number": "1"} + succeeded_pod.status.phase = "Succeeded" + + # Running pod should be selected. + running_pod = mock.MagicMock() + running_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc) + running_pod.metadata.name = "spark-driver" + running_pod.metadata.labels = {"try_number": "1"} + running_pod.status.phase = "Running" + + # Terminating pod should not be selected. 
+        terminating_pod = mock.MagicMock()
+        terminating_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc)
+        terminating_pod.metadata.deletion_timestamp = timezone.datetime(2025, 1, 2, tzinfo=timezone.utc)
+        terminating_pod.metadata.name = "spark-driver"
+        terminating_pod.metadata.labels = {"try_number": "1"}
+        terminating_pod.status.phase = "Running"
+
+        mock_get_kube_client.list_namespaced_pod.return_value.items = [
+            terminating_pod,
+            running_pod,
+            succeeded_pod,
+        ]
+
+        returned_pod = op.find_spark_job(context)
+
+        assert returned_pod is succeeded_pod

     def test_find_spark_job_picks_latest_pod(
         self,
@@ -1029,30 +1160,33 @@ def test_find_spark_job_picks_latest_pod(
             get_logs=True,
             reattach_on_restart=True,
         )
-        context = create_context(op)

-        # Older pod that should be ignored.
-        old_mock_pod = mock.MagicMock()
-        old_mock_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc)
-        old_mock_pod.metadata.name = "spark-driver-old"
-        old_mock_pod.status.phase = PodPhase.RUNNING
+        context = create_context(op)

-        # Newer pod that should be picked up.
-        new_mock_pod = mock.MagicMock()
-        new_mock_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 2, tzinfo=timezone.utc)
-        new_mock_pod.metadata.name = "spark-driver-new"
-        new_mock_pod.status.phase = PodPhase.RUNNING
+        # Latest pod should be selected. Both pods share the same phase so that
+        # creation time, not phase preference, decides the outcome.
+        new_pod = mock.MagicMock()
+        new_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 3, tzinfo=timezone.utc)
+        new_pod.metadata.name = "spark-driver"
+        new_pod.metadata.labels = {"try_number": "1"}
+        new_pod.status.phase = "Running"
+        new_pod.metadata.deletion_timestamp = None  # explicitly not terminating

-        # Same try_number to simulate abrupt failure scenarios (e.g. scheduler crash)
-        # where cleanup did not occur and multiple pods share identical labels.
-        old_mock_pod.metadata.labels = {"try_number": "1"}
-        new_mock_pod.metadata.labels = {"try_number": "1"}
+        # Older pod should not be selected.
+        old_pod = mock.MagicMock()
+        old_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc)
+        old_pod.metadata.name = "spark-driver"
+        old_pod.metadata.labels = {"try_number": "1"}
+        old_pod.status.phase = "Running"
+        old_pod.metadata.deletion_timestamp = None  # explicitly not terminating

-        mock_get_kube_client.list_namespaced_pod.return_value.items = [old_mock_pod, new_mock_pod]
+        mock_get_kube_client.list_namespaced_pod.return_value.items = [
+            old_pod,
+            new_pod,
+        ]

         returned_pod = op.find_spark_job(context)

-        assert returned_pod is new_mock_pod
+        assert returned_pod is new_pod

     def test_find_spark_job_tiebreaks_by_name(
         self,
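
Reviewer note: the selection order encoded in the new `max()` key can be hard to read off the tuple. Below is a standalone sketch of the same ranking, using plain stand-in classes (illustrative names only, not part of the patch). It shows a non-terminating Pending pod outranking a Running one, which is also why the latest-pod test above keeps both pods in the same phase:

```python
from datetime import datetime, timezone

UTC = timezone.utc


# Minimal stand-ins for the V1Pod objects returned by list_namespaced_pod.
class Meta:
    def __init__(self, name, created, deleted=None):
        self.name = name
        self.creation_timestamp = created
        self.deletion_timestamp = deleted


class Status:
    def __init__(self, phase):
        self.phase = phase


class Pod:
    def __init__(self, name, created, phase, deleted=None):
        self.metadata = Meta(name, created, deleted)
        self.status = Status(phase)


def rank(p):
    # Mirrors the patched key: non-terminating pods first, then Succeeded,
    # then Pending, then the newest creation time, with name as a tie-breaker.
    return (
        p.metadata.deletion_timestamp is None,
        p.status.phase == "Succeeded",
        p.status.phase == "Pending",
        p.metadata.creation_timestamp or datetime.min.replace(tzinfo=UTC),
        p.metadata.name or "",
    )


created = datetime(2025, 1, 1, tzinfo=UTC)
pods = [
    Pod("driver-terminating", created, "Running", deleted=datetime(2025, 1, 2, tzinfo=UTC)),
    Pod("driver-running", created, "Running"),
    Pod("driver-pending", created, "Pending"),
]
print(max(pods, key=rank).metadata.name)  # -> driver-pending
```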
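
For completeness, a minimal sketch of how the new field selector behaves against a live cluster, using the official `kubernetes` Python client. This assumes a reachable cluster and a local kubeconfig; the namespace and label are placeholders. `status.phase` supports equality-based field selectors, so two chained `!=` clauses filter on the API server rather than in the operator:

```python
from kubernetes import client, config

# Assumes a local kubeconfig; use config.load_incluster_config() inside a pod.
config.load_kube_config()
v1 = client.CoreV1Api()

# Chaining two != clauses drops Failed and Unknown pods server-side,
# leaving only Running, Pending, and Succeeded pods in the response.
pods = v1.list_namespaced_pod(
    "default",  # placeholder namespace
    label_selector="spark-role=driver",  # placeholder label
    field_selector="status.phase!=Failed,status.phase!=Unknown",
).items

for pod in pods:
    print(pod.metadata.name, pod.status.phase)
```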