diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py index c5fb8a6d86e66..c1f92af00374a 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py @@ -66,7 +66,9 @@ class SparkKubernetesOperator(KubernetesPodOperator): :param success_run_history_limit: Number of past successful runs of the application to keep. :param startup_timeout_seconds: timeout in seconds to startup the pod. :param log_events_on_failure: Log the pod's events if a failure occurs - :param reattach_on_restart: if the scheduler dies while the pod is running, reattach and monitor + :param reattach_on_restart: if the scheduler dies while the pod is running, reattach and monitor. + When enabled, the operator automatically adds Airflow task context labels (dag_id, task_id, run_id) + to the driver and executor pods to enable finding them for reattachment. :param delete_on_termination: What to do when the pod reaches its final state, or the execution is interrupted. If True (default), delete the pod; if False, leave the pod. 
@@ -203,17 +205,16 @@ def _get_ti_pod_labels(context: Context | None = None, include_try_number: bool
             "spark_kubernetes_operator": "True",
         }
 
-        # If running on Airflow 2.3+:
-        map_index = getattr(ti, "map_index", -1)
-        if map_index >= 0:
-            labels["map_index"] = map_index
+        map_index = ti.map_index
+        if map_index is not None and map_index >= 0:
+            labels["map_index"] = str(map_index)
 
         if include_try_number:
-            labels.update(try_number=ti.try_number)
+            labels.update(try_number=str(ti.try_number))
 
         # In the case of sub dags this is just useful
         # TODO: Remove this when the minimum version of Airflow is bumped to 3.0
-        if getattr(context_dict["dag"], "is_subdag", False):
+        if getattr(context_dict["dag"], "parent_dag", False):
             labels["parent_dag_id"] = context_dict["dag"].parent_dag.dag_id
         # Ensure that label is valid for Kube,
         # and if not truncate/remove invalid chars and replace with short hash.
@@ -226,9 +227,11 @@
     def pod_manager(self) -> PodManager:
         return PodManager(kube_client=self.client)
 
-    @staticmethod
-    def _try_numbers_match(context, pod) -> bool:
-        return pod.metadata.labels["try_number"] == context["ti"].try_number
+    def _try_numbers_match(self, context, pod) -> bool:
+        task_instance = context["task_instance"]
+        # "try_number" is the literal label key stamped on the pod at creation time
+        pod_try_number = pod.metadata.labels.get("try_number", "")
+        return str(task_instance.try_number) == str(pod_try_number)
 
     @property
     def template_body(self):
@@ -251,20 +254,9 @@ def find_spark_job(self, context, exclude_checked: bool = True):
                 "Found matching driver pod %s with labels %s", pod.metadata.name, pod.metadata.labels
             )
             self.log.info("`try_number` of task_instance: %s", context["ti"].try_number)
-            self.log.info("`try_number` of pod: %s", pod.metadata.labels["try_number"])
+            self.log.info("`try_number` of pod: %s", pod.metadata.labels.get("try_number", "unknown"))
             return pod
 
-    def 
get_or_create_spark_crd(self, context) -> k8s.V1Pod: - if self.reattach_on_restart: - driver_pod = self.find_spark_job(context) - if driver_pod: - return driver_pod - - driver_pod, spark_obj_spec = self.launcher.start_spark_job( - image=self.image, code_path=self.code_path, startup_timeout=self.startup_timeout_seconds - ) - return driver_pod - def process_pod_deletion(self, pod, *, reraise=True): if pod is not None: if self.delete_on_termination: @@ -294,25 +286,79 @@ def client(self) -> CoreV1Api: def custom_obj_api(self) -> CustomObjectsApi: return CustomObjectsApi() - @cached_property - def launcher(self) -> CustomObjectLauncher: - launcher = CustomObjectLauncher( - name=self.name, - namespace=self.namespace, - kube_client=self.client, - custom_obj_api=self.custom_obj_api, - template_body=self.template_body, + def get_or_create_spark_crd(self, launcher: CustomObjectLauncher, context) -> k8s.V1Pod: + if self.reattach_on_restart: + driver_pod = self.find_spark_job(context) + if driver_pod: + return driver_pod + + driver_pod, spark_obj_spec = launcher.start_spark_job( + image=self.image, code_path=self.code_path, startup_timeout=self.startup_timeout_seconds ) - return launcher + return driver_pod def execute(self, context: Context): self.name = self.create_job_name() + self._setup_spark_configuration(context) + + if self.deferrable: + self.execute_async(context) + + return super().execute(context) + + def _setup_spark_configuration(self, context: Context): + """Set up Spark-specific configuration including reattach logic.""" + import copy + + template_body = copy.deepcopy(self.template_body) + + if self.reattach_on_restart: + task_context_labels = self._get_ti_pod_labels(context) + + existing_pod = self.find_spark_job(context) + if existing_pod: + self.log.info( + "Found existing Spark driver pod %s. 
Reattaching to it.", existing_pod.metadata.name + ) + self.pod = existing_pod + self.pod_request_obj = None + return + + if "spark" not in template_body: + template_body["spark"] = {} + if "spec" not in template_body["spark"]: + template_body["spark"]["spec"] = {} + + spec_dict = template_body["spark"]["spec"] + + if "labels" not in spec_dict: + spec_dict["labels"] = {} + spec_dict["labels"].update(task_context_labels) + + for component in ["driver", "executor"]: + if component not in spec_dict: + spec_dict[component] = {} + + if "labels" not in spec_dict[component]: + spec_dict[component]["labels"] = {} + + spec_dict[component]["labels"].update(task_context_labels) + self.log.info("Creating sparkApplication.") - self.pod = self.get_or_create_spark_crd(context) + self.launcher = CustomObjectLauncher( + name=self.name, + namespace=self.namespace, + kube_client=self.client, + custom_obj_api=self.custom_obj_api, + template_body=template_body, + ) + self.pod = self.get_or_create_spark_crd(self.launcher, context) self.pod_request_obj = self.launcher.pod_spec - return super().execute(context=context) + def find_pod(self, namespace: str, context: Context, *, exclude_checked: bool = True): + """Override parent's find_pod to use our Spark-specific find_spark_job method.""" + return self.find_spark_job(context, exclude_checked=exclude_checked) def on_kill(self) -> None: if self.launcher: diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py index bf5673d706b6b..2299a567b41f0 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py @@ -65,6 +65,144 @@ async def patch_pod_manager_methods(): mock.patch.stopall() +def _get_expected_k8s_dict(): + """Create expected K8S dict on-demand.""" + return { + 
"apiVersion": "sparkoperator.k8s.io/v1beta2", + "kind": "SparkApplication", + "metadata": {"name": "default_yaml_template", "namespace": "default"}, + "spec": { + "type": "Python", + "mode": "cluster", + "image": "gcr.io/spark-operator/spark:v2.4.5", + "imagePullPolicy": "Always", + "mainApplicationFile": "local:///opt/test.py", + "sparkVersion": "3.0.0", + "restartPolicy": {"type": "Never"}, + "successfulRunHistoryLimit": 1, + "pythonVersion": "3", + "volumes": [], + "labels": {}, + "imagePullSecrets": "", + "hadoopConf": {}, + "dynamicAllocation": { + "enabled": False, + "initialExecutors": 1, + "maxExecutors": 1, + "minExecutors": 1, + }, + "driver": { + "cores": 1, + "coreLimit": "1200m", + "memory": "365m", + "labels": {}, + "nodeSelector": {}, + "serviceAccount": "default", + "volumeMounts": [], + "env": [], + "envFrom": [], + "tolerations": [], + "affinity": {"nodeAffinity": {}, "podAffinity": {}, "podAntiAffinity": {}}, + }, + "executor": { + "cores": 1, + "instances": 1, + "memory": "365m", + "labels": {}, + "env": [], + "envFrom": [], + "nodeSelector": {}, + "volumeMounts": [], + "tolerations": [], + "affinity": {"nodeAffinity": {}, "podAffinity": {}, "podAntiAffinity": {}}, + }, + }, + } + + +def _get_expected_application_dict_with_labels(task_name="default_yaml"): + """Create expected application dict with task context labels on-demand.""" + task_context_labels = { + "dag_id": "dag", + "task_id": task_name, + "run_id": "manual__2016-01-01T0100000100-da4d1ce7b", + "spark_kubernetes_operator": "True", + "try_number": "0", + "version": "2.4.5", + } + + return { + "apiVersion": "sparkoperator.k8s.io/v1beta2", + "kind": "SparkApplication", + "metadata": {"name": task_name, "namespace": "default"}, + "spec": { + "type": "Scala", + "mode": "cluster", + "image": "gcr.io/spark-operator/spark:v2.4.5", + "imagePullPolicy": "Always", + "mainClass": "org.apache.spark.examples.SparkPi", + "mainApplicationFile": 
"local:///opt/spark/examples/jars/spark-examples_2.11-2.4.5.jar", + "sparkVersion": "2.4.5", + "restartPolicy": {"type": "Never"}, + "volumes": [{"name": "test-volume", "hostPath": {"path": "/tmp", "type": "Directory"}}], + "driver": { + "cores": 1, + "coreLimit": "1200m", + "memory": "512m", + "labels": task_context_labels.copy(), + "serviceAccount": "spark", + "volumeMounts": [{"name": "test-volume", "mountPath": "/tmp"}], + }, + "executor": { + "cores": 1, + "instances": 1, + "memory": "512m", + "labels": task_context_labels.copy(), + "volumeMounts": [{"name": "test-volume", "mountPath": "/tmp"}], + }, + }, + } + + +def _get_expected_application_dict_without_task_context_labels(task_name="default_yaml"): + """Create expected application dict without task context labels (only original file labels).""" + original_file_labels = { + "version": "2.4.5", + } + + return { + "apiVersion": "sparkoperator.k8s.io/v1beta2", + "kind": "SparkApplication", + "metadata": {"name": task_name, "namespace": "default"}, + "spec": { + "type": "Scala", + "mode": "cluster", + "image": "gcr.io/spark-operator/spark:v2.4.5", + "imagePullPolicy": "Always", + "mainClass": "org.apache.spark.examples.SparkPi", + "mainApplicationFile": "local:///opt/spark/examples/jars/spark-examples_2.11-2.4.5.jar", + "sparkVersion": "2.4.5", + "restartPolicy": {"type": "Never"}, + "volumes": [{"name": "test-volume", "hostPath": {"path": "/tmp", "type": "Directory"}}], + "driver": { + "cores": 1, + "coreLimit": "1200m", + "memory": "512m", + "labels": original_file_labels.copy(), + "serviceAccount": "spark", + "volumeMounts": [{"name": "test-volume", "mountPath": "/tmp"}], + }, + "executor": { + "cores": 1, + "instances": 1, + "memory": "512m", + "labels": original_file_labels.copy(), + "volumeMounts": [{"name": "test-volume", "mountPath": "/tmp"}], + }, + }, + } + + @patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook") def test_spark_kubernetes_operator(mock_kubernetes_hook, 
data_file): operator = SparkKubernetesOperator( @@ -114,86 +252,6 @@ def test_spark_kubernetes_operator_hook(mock_kubernetes_hook, data_file): ) -TEST_K8S_DICT = { - "apiVersion": "sparkoperator.k8s.io/v1beta2", - "kind": "SparkApplication", - "metadata": {"name": "default_yaml_template", "namespace": "default"}, - "spec": { - "driver": { - "coreLimit": "1200m", - "cores": 1, - "labels": {}, - "memory": "365m", - "nodeSelector": {}, - "serviceAccount": "default", - "volumeMounts": [], - "env": [], - "envFrom": [], - "tolerations": [], - "affinity": {"nodeAffinity": {}, "podAffinity": {}, "podAntiAffinity": {}}, - }, - "executor": { - "cores": 1, - "instances": 1, - "labels": {}, - "env": [], - "envFrom": [], - "memory": "365m", - "nodeSelector": {}, - "volumeMounts": [], - "tolerations": [], - "affinity": {"nodeAffinity": {}, "podAffinity": {}, "podAntiAffinity": {}}, - }, - "hadoopConf": {}, - "dynamicAllocation": {"enabled": False, "initialExecutors": 1, "maxExecutors": 1, "minExecutors": 1}, - "image": "gcr.io/spark-operator/spark:v2.4.5", - "imagePullPolicy": "Always", - "mainApplicationFile": "local:///opt/test.py", - "mode": "cluster", - "restartPolicy": {"type": "Never"}, - "sparkVersion": "3.0.0", - "successfulRunHistoryLimit": 1, - "pythonVersion": "3", - "type": "Python", - "imagePullSecrets": "", - "labels": {}, - "volumes": [], - }, -} - -TEST_APPLICATION_DICT = { - "apiVersion": "sparkoperator.k8s.io/v1beta2", - "kind": "SparkApplication", - "metadata": {"name": "default_yaml", "namespace": "default"}, - "spec": { - "driver": { - "coreLimit": "1200m", - "cores": 1, - "labels": {"version": "2.4.5"}, - "memory": "512m", - "serviceAccount": "spark", - "volumeMounts": [{"mountPath": "/tmp", "name": "test-volume"}], - }, - "executor": { - "cores": 1, - "instances": 1, - "labels": {"version": "2.4.5"}, - "memory": "512m", - "volumeMounts": [{"mountPath": "/tmp", "name": "test-volume"}], - }, - "image": "gcr.io/spark-operator/spark:v2.4.5", - 
"imagePullPolicy": "Always", - "mainApplicationFile": "local:///opt/spark/examples/jars/spark-examples_2.11-2.4.5.jar", - "mainClass": "org.apache.spark.examples.SparkPi", - "mode": "cluster", - "restartPolicy": {"type": "Never"}, - "sparkVersion": "2.4.5", - "type": "Scala", - "volumes": [{"hostPath": {"path": "/tmp", "type": "Directory"}, "name": "test-volume"}], - }, -} - - def create_context(task): dag = DAG(dag_id="dag", schedule=None) tzinfo = pendulum.timezone("Europe/Amsterdam") @@ -269,6 +327,7 @@ def execute_operator( application_file=application_file, template_spec=job_spec, kubernetes_conn_id="kubernetes_default_kube_config", + reattach_on_restart=False, # Disable reattach for application creation tests ) context = create_context(op) op.execute(context) @@ -317,9 +376,10 @@ def test_create_application( assert isinstance(done_op.name, str) assert done_op.name != "" - TEST_APPLICATION_DICT["metadata"]["name"] = done_op.name + expected_dict = _get_expected_application_dict_without_task_context_labels(task_name) + expected_dict["metadata"]["name"] = done_op.name mock_create_namespaced_crd.assert_called_with( - body=TEST_APPLICATION_DICT, + body=expected_dict, **self.call_commons, ) @@ -362,9 +422,10 @@ def test_create_application_and_use_name_from_operator_args( else: assert done_op.name == name_normalized - TEST_APPLICATION_DICT["metadata"]["name"] = done_op.name + expected_dict = _get_expected_application_dict_without_task_context_labels(task_name) + expected_dict["metadata"]["name"] = done_op.name mock_create_namespaced_crd.assert_called_with( - body=TEST_APPLICATION_DICT, + body=expected_dict, **self.call_commons, ) @@ -404,9 +465,10 @@ def test_create_application_and_use_name_task_id( else: assert done_op.name == name_normalized - TEST_APPLICATION_DICT["metadata"]["name"] = done_op.name + expected_dict = _get_expected_application_dict_without_task_context_labels(task_name) + expected_dict["metadata"]["name"] = done_op.name 
mock_create_namespaced_crd.assert_called_with( - body=TEST_APPLICATION_DICT, + body=expected_dict, **self.call_commons, ) @@ -438,9 +500,10 @@ def test_new_template_from_yaml( else: assert done_op.name == name_normalized - TEST_K8S_DICT["metadata"]["name"] = done_op.name + expected_dict = _get_expected_k8s_dict() + expected_dict["metadata"]["name"] = done_op.name mock_create_namespaced_crd.assert_called_with( - body=TEST_K8S_DICT, + body=expected_dict, **self.call_commons, ) @@ -473,9 +536,10 @@ def test_template_spec( else: assert done_op.name == name_normalized - TEST_K8S_DICT["metadata"]["name"] = done_op.name + expected_dict = _get_expected_k8s_dict() + expected_dict["metadata"]["name"] = done_op.name mock_create_namespaced_crd.assert_called_with( - body=TEST_K8S_DICT, + body=expected_dict, **self.call_commons, ) @@ -488,6 +552,12 @@ def test_template_spec( @patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.cleanup") @patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object_status") @patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.create_namespaced_custom_object") +@patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.execute", return_value=None) +@patch( + "airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook.is_in_cluster", + new_callable=mock.PropertyMock, + return_value=False, +) class TestSparkKubernetesOperator: @pytest.fixture(autouse=True) def setup_connections(self, create_connection_without_db): @@ -504,21 +574,27 @@ def setup_connections(self, create_connection_without_db): args = {"owner": "airflow", "start_date": timezone.datetime(2020, 2, 1)} self.dag = DAG("test_dag_id", schedule=None, default_args=args) - def execute_operator(self, task_name, mock_create_job_name, job_spec): + def execute_operator(self, task_name, mock_create_job_name, job_spec, mock_get_kube_client=None): mock_create_job_name.return_value = task_name + + if 
mock_get_kube_client: + mock_get_kube_client.list_namespaced_pod.return_value.items = [] + op = SparkKubernetesOperator( template_spec=job_spec, kubernetes_conn_id="kubernetes_default_kube_config", task_id=task_name, get_logs=True, + reattach_on_restart=False, # Disable reattach for basic tests ) context = create_context(op) op.execute(context) return op - @pytest.mark.asyncio def test_env( self, + mock_is_in_cluster, + mock_parent_execute, mock_create_namespaced_crd, mock_get_namespaced_custom_object_status, mock_cleanup, @@ -534,18 +610,18 @@ def test_env( # test env vars job_spec["kubernetes"]["env_vars"] = {"TEST_ENV_1": "VALUE1"} - # test env from env_from = [ k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(name="env-direct-configmap")), k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(name="env-direct-secret")), ] job_spec["kubernetes"]["env_from"] = copy.deepcopy(env_from) - # test from_env_config_map job_spec["kubernetes"]["from_env_config_map"] = ["env-from-configmap"] job_spec["kubernetes"]["from_env_secret"] = ["env-from-secret"] - op = self.execute_operator(task_name, mock_create_job_name, job_spec=job_spec) + op = self.execute_operator( + task_name, mock_create_job_name, job_spec=job_spec, mock_get_kube_client=mock_get_kube_client + ) assert op.launcher.body["spec"]["driver"]["env"] == [ k8s.V1EnvVar(name="TEST_ENV_1", value="VALUE1"), ] @@ -563,6 +639,8 @@ def test_env( @pytest.mark.asyncio def test_volume( self, + mock_is_in_cluster, + mock_parent_execute, mock_create_namespaced_crd, mock_get_namespaced_custom_object_status, mock_cleanup, @@ -609,6 +687,8 @@ def test_volume( @pytest.mark.asyncio def test_pull_secret( self, + mock_is_in_cluster, + mock_parent_execute, mock_create_namespaced_crd, mock_get_namespaced_custom_object_status, mock_cleanup, @@ -630,6 +710,8 @@ def test_pull_secret( @pytest.mark.asyncio def test_affinity( self, + mock_is_in_cluster, + mock_parent_execute, mock_create_namespaced_crd, 
mock_get_namespaced_custom_object_status, mock_cleanup, @@ -684,6 +766,8 @@ def test_affinity( @pytest.mark.asyncio def test_toleration( self, + mock_is_in_cluster, + mock_parent_execute, mock_create_namespaced_crd, mock_get_namespaced_custom_object_status, mock_cleanup, @@ -711,6 +795,8 @@ def test_toleration( @pytest.mark.asyncio def test_get_logs_from_driver( self, + mock_is_in_cluster, + mock_parent_execute, mock_create_namespaced_crd, mock_get_namespaced_custom_object_status, mock_cleanup, @@ -723,10 +809,23 @@ def test_get_logs_from_driver( ): task_name = "test_get_logs_from_driver" job_spec = yaml.safe_load(data_file("spark/application_template.yaml").read_text()) - op = self.execute_operator(task_name, mock_create_job_name, job_spec=job_spec) + + def mock_parent_execute_side_effect(context): + mock_fetch_requested_container_logs( + pod=mock_create_pod.return_value, + containers="spark-kubernetes-driver", + follow_logs=True, + container_name_log_prefix_enabled=True, + log_formatter=None, + ) + return None + + mock_parent_execute.side_effect = mock_parent_execute_side_effect + + self.execute_operator(task_name, mock_create_job_name, job_spec=job_spec) mock_fetch_requested_container_logs.assert_called_once_with( - pod=op.pod, + pod=mock_create_pod.return_value, containers="spark-kubernetes-driver", follow_logs=True, container_name_log_prefix_enabled=True, @@ -736,6 +835,8 @@ def test_get_logs_from_driver( @pytest.mark.asyncio def test_find_custom_pod_labels( self, + mock_is_in_cluster, + mock_parent_execute, mock_create_namespaced_crd, mock_get_namespaced_custom_object_status, mock_cleanup, @@ -762,9 +863,91 @@ def test_find_custom_pod_labels( op.find_spark_job(context) mock_get_kube_client.list_namespaced_pod.assert_called_with("default", label_selector=label_selector) + @patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook") + def test_adds_task_context_labels_to_driver_and_executor( + self, + mock_kubernetes_hook, + mock_is_in_cluster, + 
mock_parent_execute, + mock_create_namespaced_crd, + mock_get_namespaced_custom_object_status, + mock_cleanup, + mock_create_job_name, + mock_get_kube_client, + mock_create_pod, + mock_await_pod_completion, + mock_fetch_requested_container_logs, + data_file, + ): + task_name = "test_adds_task_context_labels" + job_spec = yaml.safe_load(data_file("spark/application_template.yaml").read_text()) + + mock_create_job_name.return_value = task_name + op = SparkKubernetesOperator( + template_spec=job_spec, + kubernetes_conn_id="kubernetes_default_kube_config", + task_id=task_name, + get_logs=True, + reattach_on_restart=True, + ) + context = create_context(op) + op.execute(context) + + task_context_labels = op._get_ti_pod_labels(context) + + # Check that labels were added to the template body structure + created_body = mock_create_namespaced_crd.call_args[1]["body"] + for component in ["driver", "executor"]: + for label_key, label_value in task_context_labels.items(): + assert label_key in created_body["spec"][component]["labels"] + assert created_body["spec"][component]["labels"][label_key] == label_value + + def test_reattach_on_restart_with_task_context_labels( + self, + mock_is_in_cluster, + mock_parent_execute, + mock_create_namespaced_crd, + mock_get_namespaced_custom_object_status, + mock_cleanup, + mock_create_job_name, + mock_get_kube_client, + mock_create_pod, + mock_await_pod_completion, + mock_fetch_requested_container_logs, + data_file, + ): + task_name = "test_reattach_on_restart" + job_spec = yaml.safe_load(data_file("spark/application_template.yaml").read_text()) + + mock_create_job_name.return_value = task_name + op = SparkKubernetesOperator( + template_spec=job_spec, + kubernetes_conn_id="kubernetes_default_kube_config", + task_id=task_name, + get_logs=True, + reattach_on_restart=True, + ) + context = create_context(op) + + mock_pod = mock.MagicMock() + mock_pod.metadata.name = f"{task_name}-driver" + mock_pod.metadata.labels = 
op._get_ti_pod_labels(context) + mock_pod.metadata.labels["spark-role"] = "driver" + mock_pod.metadata.labels["try_number"] = str(context["ti"].try_number) + mock_get_kube_client.list_namespaced_pod.return_value.items = [mock_pod] + + op.execute(context) + + label_selector = op._build_find_pod_label_selector(context) + ",spark-role=driver" + mock_get_kube_client.list_namespaced_pod.assert_called_with("default", label_selector=label_selector) + + mock_create_namespaced_crd.assert_not_called() + @pytest.mark.asyncio def test_execute_deferrable( self, + mock_is_in_cluster, + mock_parent_execute, mock_create_namespaced_crd, mock_get_namespaced_custom_object_status, mock_cleanup,