Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions providers/cncf/kubernetes/docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
Changelog
---------

.. warning::
  ``KubernetesJobOperator`` no longer supports setting ``parallelism = 0``.
  Previously, this created a job that would never complete, so the task always failed.
  Executing a task with ``parallelism = 0`` will now raise a validation error.

10.12.0
.......
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class KubernetesJobOperator(KubernetesPodOperator):
:param completions: Specifies the desired number of successfully finished pods the job should be run with.
:param manual_selector: manualSelector controls generation of pod labels and pod selectors.
:param parallelism: Specifies the maximum desired number of pods the job should run at any given time.
The value must be >= 1. Defaults to 1.
:param selector: The selector of this V1JobSpec.
:param suspend: Suspend specifies whether the Job controller should create Pods or not.
:param ttl_seconds_after_finished: ttlSecondsAfterFinished limits the lifetime of a Job that has finished execution (either Complete or Failed).
Expand Down Expand Up @@ -114,7 +115,7 @@ def __init__(
completion_mode: str | None = None,
completions: int | None = None,
manual_selector: bool | None = None,
parallelism: int | None = None,
parallelism: int = 1,
selector: k8s.V1LabelSelector | None = None,
suspend: bool | None = None,
ttl_seconds_after_finished: int | None = None,
Expand Down Expand Up @@ -199,6 +200,15 @@ def execute(self, context: Context):
"Getting Logs and pushing to XCom are available only with parameter `wait_until_job_complete=True`. "
"Please, set it up."
)
if self.parallelism is None:
warnings.warn(
"parallelism should be set explicitly. Defaulting to 1.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
self.parallelism = 1
elif self.parallelism < 1:
raise AirflowException("parallelism cannot be less than 1.")
self.job_request_obj = self.build_job_request_obj(context)
self.job = self.create_job( # must set `self.job` for `on_kill`
job_request_obj=self.job_request_obj
Expand All @@ -208,16 +218,7 @@ def execute(self, context: Context):
ti.xcom_push(key="job_name", value=self.job.metadata.name)
ti.xcom_push(key="job_namespace", value=self.job.metadata.namespace)

self.pods: Sequence[k8s.V1Pod] | None = None
if self.parallelism is None and self.pod is None:
self.pods = [
self.get_or_create_pod(
pod_request_obj=self.pod_request_obj,
context=context,
)
]
else:
self.pods = self.get_pods(pod_request_obj=self.pod_request_obj, context=context)
self.pods: Sequence[k8s.V1Pod] = self.get_pods(pod_request_obj=self.pod_request_obj, context=context)

if self.wait_until_job_complete and self.deferrable:
self.execute_deferrable()
Expand Down Expand Up @@ -461,7 +462,9 @@ def get_pods(
pod_list: Sequence[k8s.V1Pod] = []
retry_number: int = 0

while len(pod_list) != self.parallelism or retry_number <= self.discover_pods_retry_number:
while retry_number <= self.discover_pods_retry_number:
if len(pod_list) == self.parallelism:
break
pod_list = self.client.list_namespaced_pod(
namespace=pod_request_obj.metadata.namespace,
label_selector=label_selector,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -500,12 +500,12 @@ def test_task_id_as_name_dag_id_is_ignored(self):

@pytest.mark.parametrize("randomize", [True, False])
@pytest.mark.non_db_test_override
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_or_create_pod"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_pods"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job"))
@patch(HOOK_CLASS)
def test_name_normalized_on_execution(
self, mock_hook, mock_create_job, mock_build_job_request_obj, mock_get_or_create_pod, randomize
self, mock_hook, mock_create_job, mock_build_job_request_obj, mock_get_pods, randomize
):
"""Test that names with underscores are normalized to hyphens on execution."""
name_base = "test_extra-123"
Expand All @@ -525,8 +525,7 @@ def test_name_normalized_on_execution(
task_id="task",
)

with pytest.warns(AirflowProviderDeprecationWarning):
op.execute(context=context)
op.execute(context=context)

# Verify the name was normalized (underscore replaced with hyphen)
if randomize:
Expand All @@ -535,11 +534,11 @@ def test_name_normalized_on_execution(
assert op.name == normalized_name

@pytest.mark.non_db_test_override
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_or_create_pod"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_pods"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job"))
@patch(HOOK_CLASS)
def test_execute(self, mock_hook, mock_create_job, mock_build_job_request_obj, mock_get_or_create_pod):
def test_execute(self, mock_hook, mock_create_job, mock_build_job_request_obj, mock_get_pods):
mock_hook.return_value.is_job_failed.return_value = False
mock_job_request_obj = mock_build_job_request_obj.return_value
mock_job_expected = mock_create_job.return_value
Expand All @@ -549,8 +548,7 @@ def test_execute(self, mock_hook, mock_create_job, mock_build_job_request_obj, m
op = KubernetesJobOperator(
task_id="test_task_id",
)
with pytest.warns(AirflowProviderDeprecationWarning):
execute_result = op.execute(context=context)
execute_result = op.execute(context=context)

mock_build_job_request_obj.assert_called_once_with(context)
mock_create_job.assert_called_once_with(job_request_obj=mock_job_request_obj)
Expand Down Expand Up @@ -610,7 +608,7 @@ def test_execute_with_parallelism(
assert not mock_hook.wait_until_job_complete.called

@pytest.mark.non_db_test_override
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_or_create_pod"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_pods"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.execute_deferrable"))
Expand All @@ -621,7 +619,7 @@ def test_execute_in_deferrable(
mock_execute_deferrable,
mock_create_job,
mock_build_job_request_obj,
mock_get_or_create_pod,
mock_get_pods,
):
mock_hook.return_value.is_job_failed.return_value = False
mock_job_request_obj = mock_build_job_request_obj.return_value
Expand All @@ -634,8 +632,7 @@ def test_execute_in_deferrable(
wait_until_job_complete=True,
deferrable=True,
)
with pytest.warns(AirflowProviderDeprecationWarning):
actual_result = op.execute(context=context)
actual_result = op.execute(context=context)

mock_build_job_request_obj.assert_called_once_with(context)
mock_create_job.assert_called_once_with(job_request_obj=mock_job_request_obj)
Expand All @@ -653,23 +650,20 @@ def test_execute_in_deferrable(
assert not mock_hook.wait_until_job_complete.called

@pytest.mark.non_db_test_override
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_or_create_pod"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_pods"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job"))
@patch(HOOK_CLASS)
def test_execute_fail(
self, mock_hook, mock_create_job, mock_build_job_request_obj, mock_get_or_create_pod
):
def test_execute_fail(self, mock_hook, mock_create_job, mock_build_job_request_obj, mock_get_pods):
mock_hook.return_value.is_job_failed.return_value = "Error"

op = KubernetesJobOperator(
task_id="test_task_id",
wait_until_job_complete=True,
)

with pytest.warns(AirflowProviderDeprecationWarning):
with pytest.raises(AirflowException):
op.execute(context=dict(ti=mock.MagicMock()))
with pytest.raises(AirflowException):
op.execute(context=dict(ti=mock.MagicMock()))

@pytest.mark.non_db_test_override
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.defer"))
Expand Down Expand Up @@ -789,7 +783,7 @@ def test_execute_deferrable_with_parallelism(self, mock_trigger, mock_execute_de
)
assert actual_result is None

@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_or_create_pod"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_pods"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job"))
@patch(f"{HOOK_CLASS}.wait_until_job_complete")
Expand All @@ -798,16 +792,15 @@ def test_wait_until_job_complete(
mock_wait_until_job_complete,
mock_create_job,
mock_build_job_request_obj,
mock_get_or_create_pod,
mock_get_pods,
):
mock_job_expected = mock_create_job.return_value
mock_ti = mock.MagicMock()

op = KubernetesJobOperator(
task_id="test_task_id", wait_until_job_complete=True, job_poll_interval=POLL_INTERVAL
)
with pytest.warns(AirflowProviderDeprecationWarning):
op.execute(context=dict(ti=mock_ti))
op.execute(context=dict(ti=mock_ti))

assert op.wait_until_job_complete
assert op.job_poll_interval == POLL_INTERVAL
Expand Down Expand Up @@ -913,12 +906,11 @@ def test_on_kill_none_job(self, mock_hook, mock_client):
mock_client.delete_namespaced_job.assert_not_called()
mock_serialize.assert_not_called()

@pytest.mark.parametrize("parallelism", [None, 2])
@pytest.mark.parametrize("parallelism", [1, 2])
@pytest.mark.parametrize("do_xcom_push", [True, False])
@pytest.mark.parametrize("get_logs", [True, False])
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.extract_xcom"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_pods"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_or_create_pod"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job"))
@patch(f"{POD_MANAGER_CLASS}.fetch_requested_container_logs")
Expand All @@ -933,17 +925,18 @@ def test_execute_xcom_and_logs(
mocked_fetch_logs,
mock_create_job,
mock_build_job_request_obj,
mock_get_or_create_pod,
mock_get_pods,
mock_extract_xcom,
get_logs,
do_xcom_push,
parallelism,
):
mock_pod_1 = mock.MagicMock()
if parallelism == 2:
mock_pod_1 = mock.MagicMock()
mock_pod_2 = mock.MagicMock()
mock_get_pods.return_value = [mock_pod_1, mock_pod_2]
else:
mock_get_pods.return_value = [mock_pod_1]
mock_ti = mock.MagicMock()
op = KubernetesJobOperator(
task_id="test_task_id",
Expand All @@ -954,11 +947,7 @@ def test_execute_xcom_and_logs(
parallelism=parallelism,
)

if not parallelism:
with pytest.warns(AirflowProviderDeprecationWarning):
op.execute(context=dict(ti=mock_ti))
else:
op.execute(context=dict(ti=mock_ti))
op.execute(context=dict(ti=mock_ti))

if do_xcom_push and not parallelism:
mock_extract_xcom.assert_called_once()
Expand All @@ -974,6 +963,77 @@ def test_execute_xcom_and_logs(
else:
mocked_fetch_logs.assert_not_called()

@pytest.mark.parametrize("retries", [3, 0])
@pytest.mark.parametrize("parallelism", [1, 2])
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.log_matching_pod"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator._build_find_pod_label_selector"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.client"), new_callable=mock.PropertyMock)
def test_get_pods(
    self, mock_client, mock_build_find_pod_label_selector, mock_log_matching_pod, parallelism, retries
):
    """Check that ``get_pods`` returns the pods once ``parallelism`` of them are listed.

    The mocked ``list_namespaced_pod`` answers with an empty pod list ``retries``
    times and only then with the full set of expected pods, so discovery has to
    succeed on the very last permitted attempt.
    """
    mock_context = mock.MagicMock()
    mock_build_find_pod_label_selector.return_value = {
        "dag_id": "fakedag",
        "task_id": "faketask",
        "run_id": "fakerun",
        "kubernetes_pod_operator": "True",
    }

    # One mocked pod per requested unit of parallelism.
    expected_pods = [mock.MagicMock() for _ in range(parallelism)]

    # `retries` empty answers followed by a single successful one.
    side_effects = [k8s.V1PodList(items=[]) for _ in range(retries)]
    side_effects.append(k8s.V1PodList(items=expected_pods))
    list_namespaced_pod = mock_client.return_value.list_namespaced_pod
    list_namespaced_pod.side_effect = side_effects

    op = KubernetesJobOperator(
        task_id="faketask", parallelism=parallelism, discover_pods_retry_number=retries
    )

    result = op.get_pods(mock.MagicMock(), mock_context)

    assert result == expected_pods
    # Every prepared answer must have been consumed: no early exit, no extra call.
    assert list_namespaced_pod.call_count == retries + 1

    for pod in expected_pods:
        mock_log_matching_pod.assert_any_call(pod=pod, context=mock_context)

@pytest.mark.parametrize("successful_try", [3, 1, 0])
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.log_matching_pod"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator._build_find_pod_label_selector"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.client"), new_callable=mock.PropertyMock)
def test_get_pods_retry(
    self, mock_client, mock_build_find_pod_label_selector, mock_log_matching_pod, successful_try
):
    """Check that ``get_pods`` stops listing pods as soon as the expected pod shows up.

    Only the attempt with index ``successful_try`` returns the pod; every other
    attempt returns an empty list. The number of ``list_namespaced_pod`` calls
    must therefore be exactly ``successful_try + 1``.
    """
    retries = 3
    mock_context = mock.MagicMock()
    mock_build_find_pod_label_selector.return_value = {
        "dag_id": "fakedag",
        "task_id": "faketask",
        "run_id": "fakerun",
        "kubernetes_pod_operator": "True",
    }

    mock_pod_1 = mock.MagicMock()
    expected_pods = [mock_pod_1]

    # One answer per permitted attempt (initial try + `retries`); only the
    # attempt at `successful_try` actually contains the pod.
    side_effects = [
        k8s.V1PodList(items=[mock_pod_1] if attempt == successful_try else [])
        for attempt in range(retries + 1)
    ]
    list_namespaced_pod = mock_client.return_value.list_namespaced_pod
    list_namespaced_pod.side_effect = side_effects

    op = KubernetesJobOperator(task_id="faketask", parallelism=1, discover_pods_retry_number=retries)

    result = op.get_pods(mock.MagicMock(), mock_context)

    assert result == expected_pods
    assert list_namespaced_pod.call_count == successful_try + 1


@pytest.mark.db_test
@pytest.mark.execution_timeout(300)
Expand Down