providers/cncf/kubernetes/docs/changelog.rst (2 additions, 2 deletions)

@@ -28,9 +28,9 @@ Changelog
 ---------

 .. warning::
-    ``KubernetesJobOperator`` no longer supports setting ``parallelism = 0``.
+    ``KubernetesJobOperator`` no longer supports setting ``parallelism = 0`` with ``wait_until_job_complete=True``.
     Previously this would create a job that would never complete and always fail the task.
-    Executing a task with ``parallelism = 0`` will now raise a validation error.
+    Executing a task with ``parallelism = 0`` and ``wait_until_job_complete=True`` will now raise a validation error.

 10.12.0
 .......
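
To make the documented change concrete, here is a minimal usage sketch. It is illustrative only: the task_ids are invented, and only the parameters relevant to this change are shown, mirroring how the operator is constructed in the tests further down.

from airflow.providers.cncf.kubernetes.operators.job import KubernetesJobOperator

# Still allowed: a fire-and-forget Job with parallelism=0. Because
# wait_until_job_complete=False, the operator never calls get_pods(),
# so nothing waits on pods that will never be scheduled.
suspended_job = KubernetesJobOperator(
    task_id="suspended_job",
    parallelism=0,
    wait_until_job_complete=False,
)

# Now rejected: with wait_until_job_complete=True, execute() raises
# AirflowException("parallelism cannot be less than 1 with
# `wait_until_job_complete=True`.") instead of hanging indefinitely.
invalid_job = KubernetesJobOperator(
    task_id="never_completes",
    parallelism=0,
    wait_until_job_complete=True,
)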

@@ -207,8 +207,9 @@ def execute(self, context: Context):
                 stacklevel=2,
             )
             self.parallelism = 1
-        elif self.parallelism < 1:
-            raise AirflowException("parallelism cannot be less than 1.")
+        elif self.wait_until_job_complete and self.parallelism < 1:
+            # get_pods() will raise an error if parallelism = 0
+            raise AirflowException("parallelism cannot be less than 1 with `wait_until_job_complete=True`.")
         self.job_request_obj = self.build_job_request_obj(context)
         self.job = self.create_job(  # must set `self.job` for `on_kill`
             job_request_obj=self.job_request_obj
@@ -218,13 +219,15 @@ def execute(self, context: Context):
         ti.xcom_push(key="job_name", value=self.job.metadata.name)
         ti.xcom_push(key="job_namespace", value=self.job.metadata.namespace)

-        self.pods: Sequence[k8s.V1Pod] = self.get_pods(pod_request_obj=self.pod_request_obj, context=context)
-
-        if self.wait_until_job_complete and self.deferrable:
-            self.execute_deferrable()
-            return
+        if self.wait_until_job_complete:
+            self.pods: Sequence[k8s.V1Pod] = self.get_pods(
+                pod_request_obj=self.pod_request_obj, context=context
+            )
+            if self.deferrable:
+                self.execute_deferrable()
+                return

         if self.wait_until_job_complete:
             if self.do_xcom_push:
                 xcom_result = []
                 for pod in self.pods:
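
Taken together, the two hunks above give execute() roughly the following decision flow. The snippet below is a self-contained paraphrase for orientation rather than the provider's implementation: plan_execution and its string return values are invented, and ValueError stands in for AirflowException.

def plan_execution(parallelism: int, wait_until_job_complete: bool, deferrable: bool) -> list[str]:
    """Illustrative paraphrase of the changed KubernetesJobOperator.execute() flow."""
    if wait_until_job_complete and parallelism < 1:
        # get_pods() would wait for pods that a parallelism=0 Job never schedules,
        # so this combination is rejected before anything is created.
        raise ValueError("parallelism cannot be less than 1 with `wait_until_job_complete=True`.")
    steps = ["build_job_request_obj", "create_job", "xcom_push job_name/job_namespace"]
    if wait_until_job_complete:
        steps.append("get_pods")  # pods are only fetched when waiting on the Job
        if deferrable:
            steps.append("defer to trigger")
            return steps
        steps.append("wait_until_job_complete")
    return steps


print(plan_execution(parallelism=0, wait_until_job_complete=False, deferrable=False))
# ['build_job_request_obj', 'create_job', 'xcom_push job_name/job_namespace']
print(plan_execution(parallelism=2, wait_until_job_complete=True, deferrable=True))
# ['build_job_request_obj', 'create_job', 'xcom_push job_name/job_namespace', 'get_pods', 'defer to trigger']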

@@ -500,12 +500,11 @@ def test_task_id_as_name_dag_id_is_ignored(self):

     @pytest.mark.parametrize("randomize", [True, False])
     @pytest.mark.non_db_test_override
-    @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_pods"))
     @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj"))
     @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job"))
     @patch(HOOK_CLASS)
     def test_name_normalized_on_execution(
-        self, mock_hook, mock_create_job, mock_build_job_request_obj, mock_get_pods, randomize
+        self, mock_hook, mock_create_job, mock_build_job_request_obj, randomize
     ):
         """Test that names with underscores are normalized to hyphens on execution."""
         name_base = "test_extra-123"
@@ -545,9 +544,7 @@ def test_execute(self, mock_hook, mock_create_job, mock_build_job_request_obj, m
         mock_ti = mock.MagicMock()
         context = dict(ti=mock_ti)

-        op = KubernetesJobOperator(
-            task_id="test_task_id",
-        )
+        op = KubernetesJobOperator(task_id="test_task_id", wait_until_job_complete=False)
         execute_result = op.execute(context=context)

         mock_build_job_request_obj.assert_called_once_with(context)
@@ -572,20 +569,19 @@ def test_execute(self, mock_hook, mock_create_job, mock_build_job_request_obj, m
     @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job"))
     @patch(HOOK_CLASS)
     def test_execute_with_parallelism(
-        self, mock_hook, mock_create_job, mock_build_job_request_obj, mock_get_pods
+        self,
+        mock_hook,
+        mock_create_job,
+        mock_build_job_request_obj,
+        mock_get_pods,
     ):
         mock_hook.return_value.is_job_failed.return_value = False
         mock_job_request_obj = mock_build_job_request_obj.return_value
         mock_job_expected = mock_create_job.return_value
-        mock_get_pods.return_value = [mock.MagicMock(), mock.MagicMock()]
-        mock_pods_expected = mock_get_pods.return_value
         mock_ti = mock.MagicMock()
         context = dict(ti=mock_ti)

-        op = KubernetesJobOperator(
-            task_id="test_task_id",
-            parallelism=2,
-        )
+        op = KubernetesJobOperator(task_id="test_task_id", parallelism=2, wait_until_job_complete=False)
         execute_result = op.execute(context=context)

         mock_build_job_request_obj.assert_called_once_with(context)
@@ -600,9 +596,7 @@ def test_execute_with_parallelism(

         assert op.job_request_obj == mock_job_request_obj
         assert op.job == mock_job_expected
-        assert op.pods == mock_pods_expected
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            assert op.pod is mock_pods_expected[0]
+        mock_get_pods.assert_not_called()
         assert not op.wait_until_job_complete
         assert execute_result is None
         assert not mock_hook.wait_until_job_complete.called
@@ -796,6 +790,8 @@ def test_wait_until_job_complete(
     ):
         mock_job_expected = mock_create_job.return_value
         mock_ti = mock.MagicMock()
+        mock_get_pods.return_value = [mock.MagicMock()]
+        mock_pods_expected = mock_get_pods.return_value

         op = KubernetesJobOperator(
             task_id="test_task_id", wait_until_job_complete=True, job_poll_interval=POLL_INTERVAL
@@ -809,6 +805,9 @@
             namespace=mock_job_expected.metadata.namespace,
             job_poll_interval=POLL_INTERVAL,
         )
+        assert op.pods == mock_pods_expected
+        with pytest.raises(AirflowProviderDeprecationWarning):
+            assert op.pod == mock_pods_expected[0]

     @pytest.mark.parametrize("do_xcom_push", [True, False])
     @pytest.mark.parametrize("get_logs", [True, False])
@@ -1034,6 +1033,38 @@ def test_get_pods_retry(
         assert result == return_value
         assert mock_client.return_value.list_namespaced_pod.call_count == successful_try + 1

+    @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_pods"))
+    @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj"))
+    @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.hook"), new_callable=mock.PropertyMock)
+    def test_create_zero_parallelism_job(self, mock_hook, mock_build_job_request_obj, mock_get_pods):
+        mock_context = mock.MagicMock()
+        mock_job = mock.MagicMock()
+        mock_job.to_dict.return_value = {}
+        mock_hook.return_value.get_namespace.return_value = "fakenamespace"
+        mock_build_job_request_obj.return_value = mock_job
+        op = KubernetesJobOperator(task_id="faketask", parallelism=0, wait_until_job_complete=False)
+
+        op.execute(mock_context)
+
+        mock_hook.return_value.create_job.assert_called_once_with(job=mock_job)
+        mock_get_pods.assert_not_called()
+
+    @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_pods"))
+    @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj"))
+    @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.hook"), new_callable=mock.PropertyMock)
+    def test_create_zero_parallelism_fails_validation(
+        self, mock_hook, mock_build_job_request_obj, mock_get_pods
+    ):
+        mock_context = mock.MagicMock()
+        op = KubernetesJobOperator(task_id="faketask", parallelism=0, wait_until_job_complete=True)
+
+        with pytest.raises(AirflowException):
+            op.execute(mock_context)
+
+        mock_build_job_request_obj.assert_not_called()
+        mock_hook.return_value.create_job.assert_not_called()
+        mock_get_pods.assert_not_called()
+

 @pytest.mark.db_test
 @pytest.mark.execution_timeout(300)