diff --git a/providers/cncf/kubernetes/docs/changelog.rst b/providers/cncf/kubernetes/docs/changelog.rst
index e897507453dc1..5f762a72259cf 100644
--- a/providers/cncf/kubernetes/docs/changelog.rst
+++ b/providers/cncf/kubernetes/docs/changelog.rst
@@ -27,6 +27,10 @@
 Changelog
 ---------
 
+.. warning::
+    ``KubernetesJobOperator`` no longer supports setting ``parallelism`` to a value below ``1``.
+    Previously, ``parallelism = 0`` created a job that could never complete, so the task always failed.
+    Executing a task with ``parallelism`` below ``1`` now raises a validation error.
 
 10.12.0
 .......
diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py
index a470b6eac4fff..f7e9233868e62 100644
--- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py
+++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py
@@ -81,6 +81,7 @@ class KubernetesJobOperator(KubernetesPodOperator):
     :param completions: Specifies the desired number of successfully finished pods the job should be run with.
     :param manual_selector: manualSelector controls generation of pod labels and pod selectors.
     :param parallelism: Specifies the maximum desired number of pods the job should run at any given time.
+        The value must be >= 1. Defaults to 1.
     :param selector: The selector of this V1JobSpec.
     :param suspend: Suspend specifies whether the Job controller should create Pods or not.
     :param ttl_seconds_after_finished: ttlSecondsAfterFinished limits the lifetime of a Job that has finished execution (either Complete or Failed).
@@ -114,7 +115,7 @@ def __init__(
         completion_mode: str | None = None,
         completions: int | None = None,
         manual_selector: bool | None = None,
-        parallelism: int | None = None,
+        parallelism: int = 1,
         selector: k8s.V1LabelSelector | None = None,
         suspend: bool | None = None,
         ttl_seconds_after_finished: int | None = None,
@@ -199,6 +200,15 @@ def execute(self, context: Context):
                 "Getting Logs and pushing to XCom are available only with parameter `wait_until_job_complete=True`. "
                 "Please, set it up."
             )
+        if self.parallelism is None:
+            warnings.warn(
+                "parallelism should be set explicitly. Defaulting to 1.",
+                AirflowProviderDeprecationWarning,
+                stacklevel=2,
+            )
+            self.parallelism = 1
+        elif self.parallelism < 1:
+            raise AirflowException("parallelism cannot be less than 1.")
         self.job_request_obj = self.build_job_request_obj(context)
         self.job = self.create_job(  # must set `self.job` for `on_kill`
             job_request_obj=self.job_request_obj
@@ -208,16 +218,7 @@ def execute(self, context: Context):
         ti.xcom_push(key="job_name", value=self.job.metadata.name)
         ti.xcom_push(key="job_namespace", value=self.job.metadata.namespace)
 
-        self.pods: Sequence[k8s.V1Pod] | None = None
-        if self.parallelism is None and self.pod is None:
-            self.pods = [
-                self.get_or_create_pod(
-                    pod_request_obj=self.pod_request_obj,
-                    context=context,
-                )
-            ]
-        else:
-            self.pods = self.get_pods(pod_request_obj=self.pod_request_obj, context=context)
+        self.pods: Sequence[k8s.V1Pod] = self.get_pods(pod_request_obj=self.pod_request_obj, context=context)
 
         if self.wait_until_job_complete and self.deferrable:
             self.execute_deferrable()
@@ -461,7 +462,9 @@ def get_pods(
         pod_list: Sequence[k8s.V1Pod] = []
         retry_number: int = 0
 
-        while len(pod_list) != self.parallelism or retry_number <= self.discover_pods_retry_number:
+        while retry_number <= self.discover_pods_retry_number:
+            if len(pod_list) == self.parallelism:
+                break
             pod_list = self.client.list_namespaced_pod(
                 namespace=pod_request_obj.metadata.namespace,
                 label_selector=label_selector,
diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_job.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_job.py
index 4c8289ed42b0e..0a0d91b9ccf28 100644
--- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_job.py
+++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_job.py
@@ -500,12 +500,12 @@ def test_task_id_as_name_dag_id_is_ignored(self):
 
     @pytest.mark.parametrize("randomize", [True, False])
     @pytest.mark.non_db_test_override
-    @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_or_create_pod"))
+    @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_pods"))
     @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj"))
     @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job"))
     @patch(HOOK_CLASS)
     def test_name_normalized_on_execution(
-        self, mock_hook, mock_create_job, mock_build_job_request_obj, mock_get_or_create_pod, randomize
+        self, mock_hook, mock_create_job, mock_build_job_request_obj, mock_get_pods, randomize
     ):
         """Test that names with underscores are normalized to hyphens on execution."""
         name_base = "test_extra-123"
@@ -525,8 +525,7 @@
             task_id="task",
         )
 
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            op.execute(context=context)
+        op.execute(context=context)
 
         # Verify the name was normalized (underscore replaced with hyphen)
         if randomize:
@@ -535,11 +534,11 @@
         assert op.name == normalized_name
 
     @pytest.mark.non_db_test_override
-    @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_or_create_pod"))
+    @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_pods"))
     @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj"))
     @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job"))
     @patch(HOOK_CLASS)
-    def test_execute(self, mock_hook, mock_create_job, mock_build_job_request_obj, mock_get_or_create_pod):
+    def test_execute(self, mock_hook, mock_create_job, mock_build_job_request_obj, mock_get_pods):
mock_hook.return_value.is_job_failed.return_value = False mock_job_request_obj = mock_build_job_request_obj.return_value mock_job_expected = mock_create_job.return_value @@ -549,8 +548,7 @@ def test_execute(self, mock_hook, mock_create_job, mock_build_job_request_obj, m op = KubernetesJobOperator( task_id="test_task_id", ) - with pytest.warns(AirflowProviderDeprecationWarning): - execute_result = op.execute(context=context) + execute_result = op.execute(context=context) mock_build_job_request_obj.assert_called_once_with(context) mock_create_job.assert_called_once_with(job_request_obj=mock_job_request_obj) @@ -610,7 +608,7 @@ def test_execute_with_parallelism( assert not mock_hook.wait_until_job_complete.called @pytest.mark.non_db_test_override - @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_or_create_pod")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_pods")) @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj")) @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job")) @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.execute_deferrable")) @@ -621,7 +619,7 @@ def test_execute_in_deferrable( mock_execute_deferrable, mock_create_job, mock_build_job_request_obj, - mock_get_or_create_pod, + mock_get_pods, ): mock_hook.return_value.is_job_failed.return_value = False mock_job_request_obj = mock_build_job_request_obj.return_value @@ -634,8 +632,7 @@ def test_execute_in_deferrable( wait_until_job_complete=True, deferrable=True, ) - with pytest.warns(AirflowProviderDeprecationWarning): - actual_result = op.execute(context=context) + actual_result = op.execute(context=context) mock_build_job_request_obj.assert_called_once_with(context) mock_create_job.assert_called_once_with(job_request_obj=mock_job_request_obj) @@ -653,13 +650,11 @@ def test_execute_in_deferrable( assert not mock_hook.wait_until_job_complete.called @pytest.mark.non_db_test_override - @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_or_create_pod")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_pods")) @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj")) @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job")) @patch(HOOK_CLASS) - def test_execute_fail( - self, mock_hook, mock_create_job, mock_build_job_request_obj, mock_get_or_create_pod - ): + def test_execute_fail(self, mock_hook, mock_create_job, mock_build_job_request_obj, mock_get_pods): mock_hook.return_value.is_job_failed.return_value = "Error" op = KubernetesJobOperator( @@ -667,9 +662,8 @@ def test_execute_fail( wait_until_job_complete=True, ) - with pytest.warns(AirflowProviderDeprecationWarning): - with pytest.raises(AirflowException): - op.execute(context=dict(ti=mock.MagicMock())) + with pytest.raises(AirflowException): + op.execute(context=dict(ti=mock.MagicMock())) @pytest.mark.non_db_test_override @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.defer")) @@ -789,7 +783,7 @@ def test_execute_deferrable_with_parallelism(self, mock_trigger, mock_execute_de ) assert actual_result is None - @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_or_create_pod")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_pods")) @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj")) @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job")) @patch(f"{HOOK_CLASS}.wait_until_job_complete") @@ -798,7 +792,7 @@ def test_wait_until_job_complete( mock_wait_until_job_complete, 
         mock_create_job,
         mock_build_job_request_obj,
-        mock_get_or_create_pod,
+        mock_get_pods,
     ):
         mock_job_expected = mock_create_job.return_value
         mock_ti = mock.MagicMock()
@@ -806,8 +800,7 @@
         op = KubernetesJobOperator(
             task_id="test_task_id", wait_until_job_complete=True, job_poll_interval=POLL_INTERVAL
         )
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            op.execute(context=dict(ti=mock_ti))
+        op.execute(context=dict(ti=mock_ti))
 
         assert op.wait_until_job_complete
         assert op.job_poll_interval == POLL_INTERVAL
@@ -913,12 +906,11 @@ def test_on_kill_none_job(self, mock_hook, mock_client):
         mock_client.delete_namespaced_job.assert_not_called()
         mock_serialize.assert_not_called()
 
-    @pytest.mark.parametrize("parallelism", [None, 2])
+    @pytest.mark.parametrize("parallelism", [1, 2])
     @pytest.mark.parametrize("do_xcom_push", [True, False])
     @pytest.mark.parametrize("get_logs", [True, False])
     @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.extract_xcom"))
     @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_pods"))
-    @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_or_create_pod"))
     @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj"))
     @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job"))
     @patch(f"{POD_MANAGER_CLASS}.fetch_requested_container_logs")
@@ -933,17 +925,18 @@
         mocked_fetch_logs,
         mock_create_job,
         mock_build_job_request_obj,
-        mock_get_or_create_pod,
         mock_get_pods,
         mock_extract_xcom,
         get_logs,
         do_xcom_push,
         parallelism,
     ):
+        mock_pod_1 = mock.MagicMock()
         if parallelism == 2:
-            mock_pod_1 = mock.MagicMock()
             mock_pod_2 = mock.MagicMock()
             mock_get_pods.return_value = [mock_pod_1, mock_pod_2]
+        else:
+            mock_get_pods.return_value = [mock_pod_1]
         mock_ti = mock.MagicMock()
         op = KubernetesJobOperator(
             task_id="test_task_id",
@@ -954,11 +947,7 @@
             parallelism=parallelism,
         )
 
-        if not parallelism:
-            with pytest.warns(AirflowProviderDeprecationWarning):
-                op.execute(context=dict(ti=mock_ti))
-        else:
-            op.execute(context=dict(ti=mock_ti))
+        op.execute(context=dict(ti=mock_ti))
 
-        if do_xcom_push and not parallelism:
+        if do_xcom_push and parallelism == 1:
             mock_extract_xcom.assert_called_once()
@@ -974,6 +963,77 @@
         else:
             mocked_fetch_logs.assert_not_called()
 
+    @pytest.mark.parametrize("retries", [3, 0])
+    @pytest.mark.parametrize("parallelism", [1, 2])
+    @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.log_matching_pod"))
+    @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator._build_find_pod_label_selector"))
+    @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.client"), new_callable=mock.PropertyMock)
+    def test_get_pods(
+        self, mock_client, mock_build_find_pod_label_selector, mock_log_matching_pod, parallelism, retries
+    ):
+        mock_context = mock.MagicMock()
+        mock_build_find_pod_label_selector.return_value = {
+            "dag_id": "fakedag",
+            "task_id": "faketask",
+            "run_id": "fakerun",
+            "kubernetes_pod_operator": "True",
+        }
+        mock_pod_1 = mock.MagicMock()
+        return_value = [mock_pod_1]
+        if parallelism == 2:
+            mock_pod_2 = mock.MagicMock()
+            return_value.append(mock_pod_2)
+        side_effects = []
+        for _i in range(retries):
+            side_effects.append(k8s.V1PodList(items=[]))
+        side_effects.append(k8s.V1PodList(items=return_value))
+        mock_client.return_value.list_namespaced_pod.side_effect = side_effects
+
+        op = KubernetesJobOperator(
+            task_id="faketask", parallelism=parallelism, discover_pods_retry_number=retries
+        )
+
+        result = op.get_pods(mock.MagicMock(), mock_context)
+
+        assert result == return_value
+
+        for pod in return_value:
+            mock_log_matching_pod.assert_any_call(pod=pod, context=mock_context)
+
+    @pytest.mark.parametrize("successful_try", [3, 1, 0])
+    @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.log_matching_pod"))
+    @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator._build_find_pod_label_selector"))
+    @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.client"), new_callable=mock.PropertyMock)
+    def test_get_pods_retry(
+        self, mock_client, mock_build_find_pod_label_selector, mock_log_matching_pod, successful_try
+    ):
+        retries = 3
+        mock_context = mock.MagicMock()
+        mock_build_find_pod_label_selector.return_value = {
+            "dag_id": "fakedag",
+            "task_id": "faketask",
+            "run_id": "fakerun",
+            "kubernetes_pod_operator": "True",
+        }
+
+        mock_pod_1 = mock.MagicMock()
+        return_value = [mock_pod_1]
+
+        side_effects = []
+        for i in range(retries + 1):
+            items = []
+            if i == successful_try:
+                items.append(mock_pod_1)
+            side_effects.append(k8s.V1PodList(items=items))
+        mock_client.return_value.list_namespaced_pod.side_effect = side_effects
+
+        op = KubernetesJobOperator(task_id="faketask", parallelism=1, discover_pods_retry_number=retries)
+
+        result = op.get_pods(mock.MagicMock(), mock_context)
+
+        assert result == return_value
+        assert mock_client.return_value.list_namespaced_pod.call_count == successful_try + 1
+
 
 @pytest.mark.db_test
 @pytest.mark.execution_timeout(300)
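
Reviewer note: the new contract in one place. Below is a minimal sketch of the fail-fast path, written in the same mocked style as test_execute_fail above. It assumes nothing earlier in execute() needs a live cluster, which holds for the ordering in this diff: the validation runs before build_job_request_obj and create_job, so no Kubernetes API call is ever made. The task_id is made up for illustration.

    from unittest import mock

    import pytest
    from airflow.exceptions import AirflowException
    from airflow.providers.cncf.kubernetes.operators.job import KubernetesJobOperator

    # parallelism < 1 now fails fast inside execute(), before the job request
    # is built, so a mocked TaskInstance suffices as context.
    # wait_until_job_complete=True sidesteps the unrelated logs/XCom guard that
    # execute() checks first.
    op = KubernetesJobOperator(
        task_id="bad-parallelism", parallelism=0, wait_until_job_complete=True
    )
    with pytest.raises(AirflowException, match="parallelism cannot be less than 1"):
        op.execute(context=dict(ti=mock.MagicMock()))

Passing parallelism=None explicitly still works for now: execute() emits an AirflowProviderDeprecationWarning and coerces the value to 1 before proceeding.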
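The reworked discovery loop in get_pods is also worth restating. The old condition, "len(pod_list) != self.parallelism or retry_number <= self.discover_pods_retry_number", kept the loop alive while either clause held, so if the expected pod count never appeared the loop never terminated. The sketch below isolates the new loop's shape as a standalone helper; the helper name, its parameters, and the sleep are illustrative, not the operator's actual API.

    import time
    from collections.abc import Callable, Sequence
    from typing import Any

    def discover_pods(
        list_pods: Callable[[], Sequence[Any]],
        expected_count: int,
        max_retries: int,
        poll_interval: float = 1.0,
    ) -> Sequence[Any]:
        """Poll until expected_count pods are visible or the retry budget is spent."""
        pod_list: Sequence[Any] = []
        retry_number = 0
        while retry_number <= max_retries:
            # Check-at-top mirrors the diff: a complete listing from try N is
            # detected at the start of try N + 1, which then breaks out early.
            if len(pod_list) == expected_count:
                break
            pod_list = list_pods()
            retry_number += 1
            time.sleep(poll_interval)
        # The result may still be shorter than expected_count if the budget
        # ran out; the caller decides whether that is an error.
        return pod_list

This is the shape the call_count == successful_try + 1 assertion in test_get_pods_retry relies on: listing stops as soon as a call returns the expected number of pods, and the retry budget bounds the total number of calls.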