diff --git a/airflow/providers/cncf/kubernetes/operators/pod.py b/airflow/providers/cncf/kubernetes/operators/pod.py index 124685e792e7f..49940144b5ddc 100644 --- a/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/airflow/providers/cncf/kubernetes/operators/pod.py @@ -636,15 +636,11 @@ def invoke_defer_method(self): def execute_complete(self, context: Context, event: dict, **kwargs): pod = None - remote_pod = None try: pod = self.hook.get_pod( event["name"], event["namespace"], ) - # It is done to coincide with the current implementation of the general logic of the cleanup - # method. If it's going to be remade in future then it must be changed - remote_pod = pod if event["status"] in ("error", "failed", "timeout"): # fetch some logs when pod is failed if self.get_logs: @@ -661,16 +657,13 @@ def execute_complete(self, context: Context, event: dict, **kwargs): if self.do_xcom_push: xcom_sidecar_output = self.extract_xcom(pod=pod) - pod = self.pod_manager.await_pod_completion(pod) - # It is done to coincide with the current implementation of the general logic of - # the cleanup method. If it's going to be remade in future then it must be changed - remote_pod = pod return xcom_sidecar_output finally: - if pod is not None and remote_pod is not None: + pod = self.pod_manager.await_pod_completion(pod) + if pod is not None: self.post_complete_action( pod=pod, - remote_pod=remote_pod, + remote_pod=pod, ) def write_logs(self, pod: k8s.V1Pod): diff --git a/airflow/providers/cncf/kubernetes/triggers/pod.py b/airflow/providers/cncf/kubernetes/triggers/pod.py index 6fdf763eceda7..6443dfb63f3cc 100644 --- a/airflow/providers/cncf/kubernetes/triggers/pod.py +++ b/airflow/providers/cncf/kubernetes/triggers/pod.py @@ -154,23 +154,15 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] self.log.debug("Container %s status: %s", self.base_container_name, container_state) if container_state == ContainerState.TERMINATED: - if pod_status not in PodPhase.terminal_states: - self.log.info( - "Pod %s is still running. Sleeping for %s seconds.", - self.pod_name, - self.poll_interval, - ) - await asyncio.sleep(self.poll_interval) - else: - yield TriggerEvent( - { - "name": self.pod_name, - "namespace": self.pod_namespace, - "status": "success", - "message": "All containers inside pod have started successfully.", - } - ) - return + yield TriggerEvent( + { + "name": self.pod_name, + "namespace": self.pod_namespace, + "status": "success", + "message": "All containers inside pod have started successfully.", + } + ) + return elif self.should_wait(pod_phase=pod_status, container_state=container_state): self.log.info("Container is not completed and still working.") diff --git a/tests/providers/cncf/kubernetes/operators/test_pod.py b/tests/providers/cncf/kubernetes/operators/test_pod.py index 1ce25190ba4c6..15ca6553cd107 100644 --- a/tests/providers/cncf/kubernetes/operators/test_pod.py +++ b/tests/providers/cncf/kubernetes/operators/test_pod.py @@ -1389,10 +1389,12 @@ def test_async_create_pod_should_throw_exception(self, mocked_hook, mocked_clean ({"skip_on_exit_code": None}, 100, AirflowException, "Failed", "error"), ], ) + @patch(KUB_OP_PATH.format("pod_manager")) @patch(HOOK_CLASS) def test_async_create_pod_with_skip_on_exit_code_should_skip( self, mocked_hook, + mock_manager, extra_kwargs, actual_exit_code, expected_exc, @@ -1426,6 +1428,7 @@ def test_async_create_pod_with_skip_on_exit_code_should_skip( remote_pod.status.phase = pod_status remote_pod.status.container_statuses = [base_container, sidecar_container] mocked_hook.return_value.get_pod.return_value = remote_pod + mock_manager.await_pod_completion.return_value = remote_pod context = { "ti": MagicMock(), @@ -1608,3 +1611,99 @@ def test_cleanup_log_pod_spec_on_failure(self, log_pod_spec_on_failure, expect_m pod.status = V1PodStatus(phase=PodPhase.FAILED) with pytest.raises(AirflowException, match=expect_match): k.cleanup(pod, pod) + + +@pytest.mark.parametrize("do_xcom_push", [True, False]) +@patch(KUB_OP_PATH.format("extract_xcom")) +@patch(KUB_OP_PATH.format("post_complete_action")) +@patch(HOOK_CLASS) +def test_async_kpo_wait_termination_before_cleanup_on_success( + mocked_hook, post_complete_action, mock_extract_xcom, do_xcom_push +): + metadata = {"metadata.name": TEST_NAME, "metadata.namespace": TEST_NAMESPACE} + running_state = mock.MagicMock(**metadata, **{"status.phase": "Running"}) + succeeded_state = mock.MagicMock(**metadata, **{"status.phase": "Succeeded"}) + mocked_hook.return_value.get_pod.return_value = running_state + read_pod_mock = mocked_hook.return_value.core_v1_client.read_namespaced_pod + read_pod_mock.side_effect = [ + running_state, + running_state, + succeeded_state, + ] + + ti_mock = MagicMock() + + success_event = { + "status": "success", + "message": TEST_SUCCESS_MESSAGE, + "name": TEST_NAME, + "namespace": TEST_NAMESPACE, + } + + k = KubernetesPodOperator(task_id="task", deferrable=True, do_xcom_push=do_xcom_push) + k.execute_complete({"ti": ti_mock}, success_event) + + # check if it gets the pod + mocked_hook.return_value.get_pod.assert_called_once_with(TEST_NAME, TEST_NAMESPACE) + + # check if it pushes the xcom + assert ti_mock.xcom_push.call_count == 2 + ti_mock.xcom_push.assert_any_call(key="pod_name", value=TEST_NAME) + ti_mock.xcom_push.assert_any_call(key="pod_namespace", value=TEST_NAMESPACE) + + # assert that the xcom are extracted/not extracted + if do_xcom_push: + mock_extract_xcom.assert_called_once() + else: + mock_extract_xcom.assert_not_called() + + # check if it waits for the pod to complete + assert read_pod_mock.call_count == 3 + + # assert that the cleanup is called + post_complete_action.assert_called_once() + + +@pytest.mark.parametrize("do_xcom_push", [True, False]) +@patch(KUB_OP_PATH.format("extract_xcom")) +@patch(KUB_OP_PATH.format("post_complete_action")) +@patch(HOOK_CLASS) +def test_async_kpo_wait_termination_before_cleanup_on_failure( + mocked_hook, post_complete_action, mock_extract_xcom, do_xcom_push +): + metadata = {"metadata.name": TEST_NAME, "metadata.namespace": TEST_NAMESPACE} + running_state = mock.MagicMock(**metadata, **{"status.phase": "Running"}) + failed_state = mock.MagicMock(**metadata, **{"status.phase": "Failed"}) + mocked_hook.return_value.get_pod.return_value = running_state + read_pod_mock = mocked_hook.return_value.core_v1_client.read_namespaced_pod + read_pod_mock.side_effect = [ + running_state, + running_state, + failed_state, + ] + + ti_mock = MagicMock() + + success_event = {"status": "failed", "message": "error", "name": TEST_NAME, "namespace": TEST_NAMESPACE} + + post_complete_action.side_effect = AirflowException() + + k = KubernetesPodOperator(task_id="task", deferrable=True, do_xcom_push=do_xcom_push) + + with pytest.raises(AirflowException): + k.execute_complete({"ti": ti_mock}, success_event) + + # check if it gets the pod + mocked_hook.return_value.get_pod.assert_called_once_with(TEST_NAME, TEST_NAMESPACE) + + # assert that it does not push the xcom + ti_mock.xcom_push.assert_not_called() + + # assert that the xcom are not extracted + mock_extract_xcom.assert_not_called() + + # check if it waits for the pod to complete + assert read_pod_mock.call_count == 3 + + # assert that the cleanup is called + post_complete_action.assert_called_once() diff --git a/tests/providers/cncf/kubernetes/triggers/test_pod.py b/tests/providers/cncf/kubernetes/triggers/test_pod.py index 4ed731b42578f..fbfff17278c7f 100644 --- a/tests/providers/cncf/kubernetes/triggers/test_pod.py +++ b/tests/providers/cncf/kubernetes/triggers/test_pod.py @@ -96,8 +96,7 @@ def test_serialize(self, trigger): @mock.patch(f"{TRIGGER_PATH}.define_container_state") @mock.patch(f"{TRIGGER_PATH}._get_async_hook") async def test_run_loop_return_success_event(self, mock_hook, mock_method, trigger): - pod_mock = mock.MagicMock(**{"status.phase": "Succeeded"}) - mock_hook.return_value.get_pod.return_value = self._mock_pod_result(pod_mock) + mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock()) mock_method.return_value = ContainerState.TERMINATED expected_event = TriggerEvent( @@ -112,35 +111,6 @@ async def test_run_loop_return_success_event(self, mock_hook, mock_method, trigg assert actual_event == expected_event - @pytest.mark.asyncio - @mock.patch(f"{TRIGGER_PATH}.define_container_state") - @mock.patch(f"{TRIGGER_PATH}._get_async_hook") - async def test_run_loop_wait_pod_termination_before_returning_success_event( - self, mock_hook, mock_method, trigger - ): - running_state = mock.MagicMock(**{"status.phase": "Running"}) - succeeded_state = mock.MagicMock(**{"status.phase": "Succeeded"}) - mock_hook.return_value.get_pod.side_effect = [ - self._mock_pod_result(running_state), - self._mock_pod_result(running_state), - self._mock_pod_result(succeeded_state), - ] - mock_method.return_value = ContainerState.TERMINATED - - expected_event = TriggerEvent( - { - "name": POD_NAME, - "namespace": NAMESPACE, - "status": "success", - "message": "All containers inside pod have started successfully.", - } - ) - with mock.patch.object(asyncio, "sleep") as mock_sleep: - actual_event = await (trigger.run()).asend(None) - - assert actual_event == expected_event - assert mock_sleep.call_count == 2 - @pytest.mark.asyncio @mock.patch(f"{TRIGGER_PATH}.define_container_state") @mock.patch(f"{TRIGGER_PATH}._get_async_hook") diff --git a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py index 154908a6c4b98..e695822d3863f 100644 --- a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py +++ b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py @@ -110,13 +110,7 @@ def test_serialize_should_execute_successfully(self, trigger): async def test_run_loop_return_success_event_should_execute_successfully( self, mock_hook, mock_method, trigger ): - running_state = mock.MagicMock(**{"status.phase": "Running"}) - succeeded_state = mock.MagicMock(**{"status.phase": "Succeeded"}) - mock_hook.return_value.get_pod.side_effect = [ - self._mock_pod_result(running_state), - self._mock_pod_result(running_state), - self._mock_pod_result(succeeded_state), - ] + mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock()) mock_method.return_value = ContainerState.TERMINATED expected_event = TriggerEvent( @@ -127,11 +121,9 @@ async def test_run_loop_return_success_event_should_execute_successfully( "message": "All containers inside pod have started successfully.", } ) - with mock.patch.object(asyncio, "sleep") as mock_sleep: - actual_event = await (trigger.run()).asend(None) + actual_event = await (trigger.run()).asend(None) assert actual_event == expected_event - assert mock_sleep.call_count == 2 @pytest.mark.asyncio @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")