Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -735,20 +735,9 @@ def execute_sync(self, context: Context):
)
finally:
pod_to_clean = self.pod or self.pod_request_obj
self.cleanup(
pod=pod_to_clean,
remote_pod=self.remote_pod,
xcom_result=result,
context=context,
self.post_complete_action(
pod=pod_to_clean, remote_pod=self.remote_pod, context=context, result=result
)
for callback in self.callbacks:
callback.on_pod_cleanup(
pod=pod_to_clean,
client=self.client,
mode=ExecutionMode.SYNC,
context=context,
operator=self,
)

if self.do_xcom_push:
return result
Expand Down Expand Up @@ -819,11 +808,20 @@ def _refresh_cached_properties(self):
def execute_async(self, context: Context) -> None:
if self.pod_request_obj is None:
self.pod_request_obj = self.build_pod_request_obj(context)
for callback in self.callbacks:
callback.on_pod_manifest_created(
pod_request=self.pod_request_obj,
client=self.client,
mode=ExecutionMode.SYNC,
context=context,
operator=self,
)
if self.pod is None:
self.pod = self.get_or_create_pod( # must set `self.pod` for `on_kill`
pod_request_obj=self.pod_request_obj,
context=context,
)

if self.callbacks:
pod = self.find_pod(self.pod.metadata.namespace, context=context)
for callback in self.callbacks:
Expand Down Expand Up @@ -886,6 +884,7 @@ def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:
grab the latest logs and defer back to the trigger again.
"""
self.pod = None
xcom_sidecar_output = None
try:
pod_name = event["name"]
pod_namespace = event["namespace"]
Expand All @@ -909,20 +908,37 @@ def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:
follow = self.logging_interval is None
last_log_time = event.get("last_log_time")

if event["status"] in ("error", "failed", "timeout"):
event_message = event.get("message", "No message provided")
self.log.error(
"Trigger emitted an %s event, failing the task: %s", event["status"], event_message
)
# fetch some logs when pod is failed
if event["status"] in ("error", "failed", "timeout", "success"):
if self.get_logs:
self._write_logs(self.pod, follow=follow, since_time=last_log_time)

if self.do_xcom_push:
_ = self.extract_xcom(pod=self.pod)
for callback in self.callbacks:
callback.on_pod_completion(
pod=self.pod,
client=self.client,
mode=ExecutionMode.SYNC,
context=context,
operator=self,
)
for callback in self.callbacks:
callback.on_pod_teardown(
pod=self.pod,
client=self.client,
mode=ExecutionMode.SYNC,
context=context,
operator=self,
)

xcom_sidecar_output = self.extract_xcom(pod=self.pod) if self.do_xcom_push else None

if event["status"] != "success":
self.log.error(
"Trigger emitted an %s event, failing the task: %s", event["status"], event["message"]
)
message = event.get("stack_trace", event["message"])
raise AirflowException(message)

message = event.get("stack_trace", event["message"])
raise AirflowException(message)
return xcom_sidecar_output

if event["status"] == "running":
if self.get_logs:
Expand All @@ -940,22 +956,12 @@ def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:
self.invoke_defer_method(pod_log_status.last_log_time)
else:
self.invoke_defer_method()

elif event["status"] == "success":
# fetch some logs when pod is executed successfully
if self.get_logs:
self._write_logs(self.pod, follow=follow, since_time=last_log_time)

if self.do_xcom_push:
xcom_sidecar_output = self.extract_xcom(pod=self.pod)
return xcom_sidecar_output
return
except TaskDeferred:
raise
finally:
self._clean(event, context)
self._clean(event=event, context=context, result=xcom_sidecar_output)

def _clean(self, event: dict[str, Any], context: Context) -> None:
def _clean(self, event: dict[str, Any], result: dict | None, context: Context) -> None:
if event["status"] == "running":
return
istio_enabled = self.is_istio_enabled(self.pod)
Expand All @@ -979,6 +985,7 @@ def _clean(self, event: dict[str, Any], context: Context) -> None:
pod=self.pod,
remote_pod=self.pod,
context=context,
result=result,
)

def _write_logs(self, pod: k8s.V1Pod, follow: bool = False, since_time: DateTime | None = None) -> None:
Expand Down Expand Up @@ -1008,11 +1015,15 @@ def _write_logs(self, pod: k8s.V1Pod, follow: bool = False, since_time: DateTime
e if not isinstance(e, ApiException) else e.reason,
)

def post_complete_action(self, *, pod, remote_pod, context: Context, **kwargs) -> None:
def post_complete_action(
self, *, pod: k8s.V1Pod, remote_pod: k8s.V1Pod, context: Context, result: dict | None, **kwargs
) -> None:
"""Actions that must be done after operator finishes logic of the deferrable_execution."""
self.cleanup(
pod=pod,
remote_pod=remote_pod,
xcom_result=result,
context=context,
)
for callback in self.callbacks:
callback.on_pod_cleanup(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2648,6 +2648,84 @@ def test_trigger_error(self, find_pod, cleanup, mock_write_log):
},
)

@patch(HOOK_CLASS)
def test_execute_async_callbacks(self, mocked_hook):
from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode

from unit.cncf.kubernetes.test_callbacks import (
MockKubernetesPodOperatorCallback,
MockWrapper,
)

MockWrapper.reset()
mock_callbacks = MockWrapper.mock_callbacks
remote_pod_mock = MagicMock()
remote_pod_mock.status.phase = "Succeeded"
self.await_pod_mock.return_value = remote_pod_mock
mocked_hook.return_value.get_pod.return_value = remote_pod_mock

k = KubernetesPodOperator(
namespace="default",
image="ubuntu:16.04",
cmds=["bash", "-cx"],
arguments=["echo 10"],
labels={"foo": "bar"},
name="test",
task_id="task",
do_xcom_push=False,
callbacks=MockKubernetesPodOperatorCallback,
)
context = create_context(k)

callback_event = {
"status": "success",
"message": TEST_SUCCESS_MESSAGE,
"name": TEST_NAME,
"namespace": TEST_NAMESPACE,
}
k.trigger_reentry(context=context, event=callback_event)

# check on_operator_resuming callback
mock_callbacks.on_operator_resuming.assert_called_once()
assert mock_callbacks.on_operator_resuming.call_args.kwargs == {
"client": k.client,
"mode": ExecutionMode.SYNC,
"pod": remote_pod_mock,
"operator": k,
"context": context,
"event": callback_event,
}

# check on_pod_cleanup callback
mock_callbacks.on_pod_cleanup.assert_called_once()
assert mock_callbacks.on_pod_cleanup.call_args.kwargs == {
"client": k.client,
"mode": ExecutionMode.SYNC,
"pod": remote_pod_mock,
"operator": k,
"context": context,
}

# check on_pod_completion callback
mock_callbacks.on_pod_completion.assert_called_once()
assert mock_callbacks.on_pod_completion.call_args.kwargs == {
"client": k.client,
"mode": ExecutionMode.SYNC,
"pod": remote_pod_mock,
"operator": k,
"context": context,
}

# check on_pod_teardown callback
mock_callbacks.on_pod_teardown.assert_called_once()
assert mock_callbacks.on_pod_teardown.call_args.kwargs == {
"client": k.client,
"mode": ExecutionMode.SYNC,
"pod": remote_pod_mock,
"operator": k,
"context": context,
}


@pytest.mark.parametrize("do_xcom_push", [True, False])
@patch(KUB_OP_PATH.format("extract_xcom"))
Expand Down
Loading