diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py index a83bc7a3de387..c3a20af81be92 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py @@ -238,7 +238,9 @@ def pod_manager(self) -> PodManager: def get_body(self): self.body: dict = SparkJobSpec(**self.template_body["spark"]) - self.body.metadata = {"name": self.name, "namespace": self.namespace} + if not hasattr(self.body, "metadata") or not isinstance(self.body.metadata, dict): + self.body.metadata = {} + self.body.metadata.update({"name": self.name, "namespace": self.namespace}) if self.template_body.get("kubernetes"): k8s_spec: dict = KubernetesSpec(**self.template_body["kubernetes"]) self.body.spec["volumes"] = k8s_spec.volumes diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_custom_object_launcher.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_custom_object_launcher.py index eab85e5a87339..078547af56183 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_custom_object_launcher.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_custom_object_launcher.py @@ -213,6 +213,27 @@ def get_pod_status(self, reason: str, message: str | None = None): ] ) + def test_get_body_initializes_metadata_when_missing(self, mock_launcher): + mock_launcher.template_body["spark"].pop("metadata", None) + body = mock_launcher.get_body() + assert isinstance(body["metadata"], dict) + assert body["metadata"]["name"] == mock_launcher.name + assert body["metadata"]["namespace"] == mock_launcher.namespace + + def test_get_body_replaces_non_dict_metadata(self, mock_launcher): + mock_launcher.template_body["spark"]["metadata"] = "not-a-dict" + body = mock_launcher.get_body() + assert isinstance(body["metadata"], dict) + assert body["metadata"]["name"] == mock_launcher.name + assert body["metadata"]["namespace"] == mock_launcher.namespace + + def test_get_body_preserves_existing_metadata_labels(self, mock_launcher): + mock_launcher.template_body["spark"]["metadata"] = {"labels": {"team": "data"}} + body = mock_launcher.get_body() + assert body["metadata"]["labels"]["team"] == "data" + assert body["metadata"]["name"] == mock_launcher.name + assert body["metadata"]["namespace"] == mock_launcher.namespace + @patch("airflow.providers.cncf.kubernetes.operators.custom_object_launcher.PodManager") def test_start_spark_job_no_error(self, mock_pod_manager, mock_launcher): mock_launcher.start_spark_job()