diff --git a/providers/cncf/kubernetes/docs/changelog.rst b/providers/cncf/kubernetes/docs/changelog.rst index 5f762a72259cf..94ea1b79e9d76 100644 --- a/providers/cncf/kubernetes/docs/changelog.rst +++ b/providers/cncf/kubernetes/docs/changelog.rst @@ -28,9 +28,9 @@ Changelog --------- .. warning:: - ``KubernetesJobOperator`` no longer supports setting ``parallelism = 0``. + ``KubernetesJobOperator`` no longer supports setting ``parallelism = 0`` with ``wait_until_job_complete=True``. Previously this would create a job that would never complete and always fail the task. - Executing a task with ``parallelism = 0`` will now raise a validation error. + Executing a task with ``parallelism = 0`` and ``wait_until_job_complete=True`` will now raise a validation error. 10.12.0 ....... diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py index f7e9233868e62..52c43fd1130ac 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py @@ -207,8 +207,9 @@ def execute(self, context: Context): stacklevel=2, ) self.parallelism = 1 - elif self.parallelism < 1: - raise AirflowException("parallelism cannot be less than 1.") + elif self.wait_until_job_complete and self.parallelism < 1: + # get_pods() will raise an error if parallelism = 0 + raise AirflowException("parallelism cannot be less than 1 with `wait_until_job_complete=True`.") self.job_request_obj = self.build_job_request_obj(context) self.job = self.create_job( # must set `self.job` for `on_kill` job_request_obj=self.job_request_obj @@ -218,13 +219,15 @@ def execute(self, context: Context): ti.xcom_push(key="job_name", value=self.job.metadata.name) ti.xcom_push(key="job_namespace", value=self.job.metadata.namespace) - self.pods: Sequence[k8s.V1Pod] = self.get_pods(pod_request_obj=self.pod_request_obj, context=context) + if self.wait_until_job_complete: + self.pods: Sequence[k8s.V1Pod] = self.get_pods( + pod_request_obj=self.pod_request_obj, context=context + ) - if self.wait_until_job_complete and self.deferrable: - self.execute_deferrable() - return + if self.deferrable: + self.execute_deferrable() + return - if self.wait_until_job_complete: if self.do_xcom_push: xcom_result = [] for pod in self.pods: diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_job.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_job.py index 0a0d91b9ccf28..ab0b7943a3b57 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_job.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_job.py @@ -500,12 +500,11 @@ def test_task_id_as_name_dag_id_is_ignored(self): @pytest.mark.parametrize("randomize", [True, False]) @pytest.mark.non_db_test_override - @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_pods")) @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj")) @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job")) @patch(HOOK_CLASS) def test_name_normalized_on_execution( - self, mock_hook, mock_create_job, mock_build_job_request_obj, mock_get_pods, randomize + self, mock_hook, mock_create_job, mock_build_job_request_obj, randomize ): """Test that names with underscores are normalized to hyphens on execution.""" name_base = "test_extra-123" @@ -545,9 +544,7 @@ def test_execute(self, mock_hook, mock_create_job, mock_build_job_request_obj, m mock_ti = mock.MagicMock() context = dict(ti=mock_ti) - op = KubernetesJobOperator( - task_id="test_task_id", - ) + op = KubernetesJobOperator(task_id="test_task_id", wait_until_job_complete=False) execute_result = op.execute(context=context) mock_build_job_request_obj.assert_called_once_with(context) @@ -572,20 +569,19 @@ def test_execute(self, mock_hook, mock_create_job, mock_build_job_request_obj, m @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job")) @patch(HOOK_CLASS) def test_execute_with_parallelism( - self, mock_hook, mock_create_job, mock_build_job_request_obj, mock_get_pods + self, + mock_hook, + mock_create_job, + mock_build_job_request_obj, + mock_get_pods, ): mock_hook.return_value.is_job_failed.return_value = False mock_job_request_obj = mock_build_job_request_obj.return_value mock_job_expected = mock_create_job.return_value - mock_get_pods.return_value = [mock.MagicMock(), mock.MagicMock()] - mock_pods_expected = mock_get_pods.return_value mock_ti = mock.MagicMock() context = dict(ti=mock_ti) - op = KubernetesJobOperator( - task_id="test_task_id", - parallelism=2, - ) + op = KubernetesJobOperator(task_id="test_task_id", parallelism=2, wait_until_job_complete=False) execute_result = op.execute(context=context) mock_build_job_request_obj.assert_called_once_with(context) @@ -600,9 +596,7 @@ def test_execute_with_parallelism( assert op.job_request_obj == mock_job_request_obj assert op.job == mock_job_expected - assert op.pods == mock_pods_expected - with pytest.warns(AirflowProviderDeprecationWarning): - assert op.pod is mock_pods_expected[0] + mock_get_pods.assert_not_called() assert not op.wait_until_job_complete assert execute_result is None assert not mock_hook.wait_until_job_complete.called @@ -796,6 +790,8 @@ def test_wait_until_job_complete( ): mock_job_expected = mock_create_job.return_value mock_ti = mock.MagicMock() + mock_get_pods.return_value = [mock.MagicMock()] + mock_pods_expected = mock_get_pods.return_value op = KubernetesJobOperator( task_id="test_task_id", wait_until_job_complete=True, job_poll_interval=POLL_INTERVAL @@ -809,6 +805,9 @@ def test_wait_until_job_complete( namespace=mock_job_expected.metadata.namespace, job_poll_interval=POLL_INTERVAL, ) + assert op.pods == mock_pods_expected + with pytest.raises(AirflowProviderDeprecationWarning): + assert op.pod == mock_pods_expected[0] @pytest.mark.parametrize("do_xcom_push", [True, False]) @pytest.mark.parametrize("get_logs", [True, False]) @@ -1034,6 +1033,38 @@ def test_get_pods_retry( assert result == return_value assert mock_client.return_value.list_namespaced_pod.call_count == successful_try + 1 + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_pods")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.hook"), new_callable=mock.PropertyMock) + def test_create_zero_parallelism_job(self, mock_hook, mock_build_job_request_obj, mock_get_pods): + mock_context = mock.MagicMock() + mock_job = mock.MagicMock() + mock_job.to_dict.return_value = {} + mock_hook.return_value.get_namespace.return_value = "fakenamespace" + mock_build_job_request_obj.return_value = mock_job + op = KubernetesJobOperator(task_id="faketask", parallelism=0, wait_until_job_complete=False) + + op.execute(mock_context) + + mock_hook.return_value.create_job.assert_called_once_with(job=mock_job) + mock_get_pods.assert_not_called() + + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_pods")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.hook"), new_callable=mock.PropertyMock) + def test_create_zero_parallelism_fails_validation( + self, mock_hook, mock_build_job_request_obj, mock_get_pods + ): + mock_context = mock.MagicMock() + op = KubernetesJobOperator(task_id="faketask", parallelism=0, wait_until_job_complete=True) + + with pytest.raises(AirflowException): + op.execute(mock_context) + + mock_build_job_request_obj.assert_not_called() + mock_hook.return_value.create_job.assert_not_called() + mock_get_pods.assert_not_called() + @pytest.mark.db_test @pytest.mark.execution_timeout(300)