Skip to content
2 changes: 1 addition & 1 deletion kubernetes-tests/tests/kubernetes_tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@

class StringContainingId(str):
    """
    String matcher that compares equal to any string containing it.

    Intended for use as an argument in mock assertions (for example
    ``mock_logger.info.assert_any_call("%s", StringContainingId("needle"))``),
    where the exact logged text is not known but a fragment of it is.
    """

    def __eq__(self, other):
        """
        Return True when ``self`` is a substring of ``other``.

        :param other: the value being compared against this matcher.
        :return: ``True``/``False`` for string operands, ``NotImplemented``
            for anything else so Python falls back to its default comparison
            (which yields ``False``) instead of raising ``TypeError``.
        """
        # A mock scans every recorded call while looking for a match, so this
        # may be compared against arbitrary non-string arguments (None, ints,
        # ...). ``self in other`` would raise TypeError for those.
        if not isinstance(other, str):
            return NotImplemented
        # ``self in other.strip() or self in other`` is logically the same as
        # ``self in other``: anything contained in the stripped text is also
        # contained in the unstripped text.
        return self in other


class BaseK8STest:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -552,10 +552,11 @@ def test_volume_mount(self):
task_id=str(uuid4()),
in_cluster=False,
do_xcom_push=False,
container_name_log_prefix_enabled=False,
)
context = create_context(k)
k.execute(context=context)
mock_logger.info.assert_any_call("[%s] %s", "base", StringContainingId("retrieved from mount"))
mock_logger.info.assert_any_call("%s", StringContainingId("retrieved from mount"))
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod["spec"]["containers"][0]["args"] = args
self.expected_pod["spec"]["containers"][0]["volumeMounts"] = [
Expand Down Expand Up @@ -1428,6 +1429,64 @@ def test_init_container_logs_filtered(self):
< calls_args.find(marker_from_main_container)
)

@pytest.mark.parametrize(
    "log_prefix_enabled, log_formatter, expected_log_message_check",
    [
        pytest.param(
            True,
            None,
            lambda marker, record_message: f"[base] {marker}" in record_message,
            id="log_prefix_enabled",
        ),
        pytest.param(
            False,
            None,
            lambda marker, record_message: marker in record_message and "[base]" not in record_message,
            id="log_prefix_disabled",
        ),
        pytest.param(
            # The prefix flag has no effect once a log_formatter is supplied.
            False,
            lambda container_name, message: f"CUSTOM[{container_name}]: {message}",
            lambda marker, record_message: f"CUSTOM[base]: {marker}" in record_message,
            id="custom_log_formatter",
        ),
    ],
)
def test_log_output_configurations(self, log_prefix_enabled, log_formatter, expected_log_message_check):
    """
    Check how ``PodManager._log_message`` renders a line for each logging setup.

    Covers three cases: container-name prefix enabled, prefix disabled, and a
    user-supplied ``log_formatter``. ``expected_log_message_check`` is a
    predicate ``(marker, logged_message) -> bool`` describing the expected text.
    """
    marker = f"test_log_{uuid4()}"
    operator = KubernetesPodOperator(
        namespace="default",
        image="busybox",
        cmds=["sh", "-cx"],
        arguments=[f"echo {marker}"],
        labels={"test_label": "test"},
        task_id=str(uuid4()),
        in_cluster=False,
        do_xcom_push=False,
        get_logs=True,
        container_name_log_prefix_enabled=log_prefix_enabled,
        log_formatter=log_formatter,
    )

    # Exercise _log_message directly and capture what it hands to the logger.
    pod_manager_logger = logging.getLogger("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager")
    with mock.patch.object(pod_manager_logger, "info") as info_spy:
        operator.pod_manager._log_message(
            message=marker,
            container_name="base",
            container_name_log_prefix_enabled=log_prefix_enabled,
            log_formatter=log_formatter,
        )

    # Exactly one record is expected; its positional args are ("%s", <text>),
    # so index 1 holds the rendered message.
    info_spy.assert_called_once()
    logged_message = info_spy.call_args.args[1]
    assert expected_log_message_check(marker, logged_message)


# TODO: Task SDK: https://github.com/apache/airflow/issues/45438
@pytest.mark.skip(reason="AIP-72: Secret Masking yet to be implemented")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,11 @@ class KubernetesPodOperator(BaseOperator):
resuming to fetch the latest logs. If ``None``, then the task will remain in deferred state until pod
is done, and no logs will be visible until that time.
:param trigger_kwargs: additional keyword parameters passed to the trigger
:param container_name_log_prefix_enabled: if True, each log line is prefixed with the name of
the container that produced it. Defaults to True.
:param log_formatter: custom log formatter function that takes two string arguments:
the first string is the container_name and the second string is the message_to_log.
The function should return a formatted string. If None, the default formatting will be used.
"""

# !!! Changes in KubernetesPodOperator's arguments should be also reflected in !!!
Expand Down Expand Up @@ -343,6 +348,8 @@ def __init__(
progress_callback: Callable[[str], None] | None = None,
logging_interval: int | None = None,
trigger_kwargs: dict | None = None,
container_name_log_prefix_enabled: bool = True,
log_formatter: Callable[[str, str], str] | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand Down Expand Up @@ -438,6 +445,8 @@ def __init__(
self._progress_callback = progress_callback
self.callbacks = [] if not callbacks else callbacks if isinstance(callbacks, list) else [callbacks]
self._killed: bool = False
self.container_name_log_prefix_enabled = container_name_log_prefix_enabled
self.log_formatter = log_formatter

@cached_property
def _incluster_namespace(self):
Expand Down Expand Up @@ -750,6 +759,8 @@ def await_init_containers_completion(self, pod: k8s.V1Pod):
pod=pod,
init_containers=self.init_container_logs,
follow_logs=True,
container_name_log_prefix_enabled=self.container_name_log_prefix_enabled,
log_formatter=self.log_formatter,
)
except kubernetes.client.exceptions.ApiException as exc:
self._handle_api_exception(exc, pod)
Expand All @@ -766,6 +777,8 @@ def await_pod_completion(self, pod: k8s.V1Pod):
pod=pod,
containers=self.container_logs,
follow_logs=True,
container_name_log_prefix_enabled=self.container_name_log_prefix_enabled,
log_formatter=self.log_formatter,
)
if not self.get_logs or (
self.container_logs is not True and self.base_container_name not in self.container_logs
Expand Down Expand Up @@ -914,6 +927,8 @@ def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:
container_name=self.base_container_name,
follow=follow,
since_time=last_log_time,
container_name_log_prefix_enabled=self.container_name_log_prefix_enabled,
log_formatter=self.log_formatter,
)

self.invoke_defer_method(pod_log_status.last_log_time)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import json
import math
import time
from collections.abc import Generator, Iterable
from collections.abc import Callable, Generator, Iterable
from contextlib import closing, suppress
from dataclasses import dataclass
from datetime import timedelta
Expand Down Expand Up @@ -456,6 +456,26 @@ async def await_pod_start(

await asyncio.sleep(check_interval)

def _log_message(
    self,
    message: str,
    container_name: str,
    container_name_log_prefix_enabled: bool,
    log_formatter: Callable[[str, str], str] | None,
) -> None:
    """
    Emit one container log line, applying the configured formatting.

    :param message: the log text to emit.
    :param container_name: name of the container the text belongs to.
    :param container_name_log_prefix_enabled: when True and no ``log_formatter``
        is given, the text is logged as ``[<container_name>] <message>``;
        when False the text is logged unchanged.
    :param log_formatter: optional callable receiving ``(container_name, message)``
        and returning the string to log. When provided it takes precedence over
        ``container_name_log_prefix_enabled``.
    """
    # Messages recognised by is_log_group_marker are printed as-is instead of
    # being routed through the logger, and never get prefixed or formatted.
    if is_log_group_marker(message):
        print(message)
        return

    if log_formatter:
        rendered = log_formatter(container_name, message)
    elif container_name_log_prefix_enabled:
        rendered = f"[{container_name}] {message}"
    else:
        rendered = message
    self.log.info("%s", rendered)

def fetch_container_logs(
self,
pod: V1Pod,
Expand All @@ -464,6 +484,8 @@ def fetch_container_logs(
follow=False,
since_time: DateTime | None = None,
post_termination_timeout: int = 120,
container_name_log_prefix_enabled: bool = True,
log_formatter: Callable[[str, str], str] | None = None,
) -> PodLoggingStatus:
"""
Follow the logs of container and stream to airflow logging.
Expand Down Expand Up @@ -529,10 +551,12 @@ def consume_logs(*, since_time: DateTime | None = None) -> tuple[DateTime | None
line=line, client=self._client, mode=ExecutionMode.SYNC
)
if message_to_log is not None:
if is_log_group_marker(message_to_log):
print(message_to_log)
else:
self.log.info("[%s] %s", container_name, message_to_log)
self._log_message(
message_to_log,
container_name,
container_name_log_prefix_enabled,
log_formatter,
)
last_captured_timestamp = message_timestamp
message_to_log = message
message_timestamp = line_timestamp
Expand All @@ -548,10 +572,9 @@ def consume_logs(*, since_time: DateTime | None = None) -> tuple[DateTime | None
line=line, client=self._client, mode=ExecutionMode.SYNC
)
if message_to_log is not None:
if is_log_group_marker(message_to_log):
print(message_to_log)
else:
self.log.info("[%s] %s", container_name, message_to_log)
self._log_message(
message_to_log, container_name, container_name_log_prefix_enabled, log_formatter
)
last_captured_timestamp = message_timestamp
except TimeoutError as e:
# in case of timeout, increment return time by 2 seconds to avoid
Expand Down Expand Up @@ -630,7 +653,12 @@ def _reconcile_requested_log_containers(
return containers_to_log

def fetch_requested_init_container_logs(
self, pod: V1Pod, init_containers: Iterable[str] | str | Literal[True] | None, follow_logs=False
self,
pod: V1Pod,
init_containers: Iterable[str] | str | Literal[True] | None,
follow_logs=False,
container_name_log_prefix_enabled: bool = True,
log_formatter: Callable[[str, str], str] | None = None,
) -> list[PodLoggingStatus]:
"""
Follow the logs of containers in the specified pod and publish it to airflow logging.
Expand All @@ -650,12 +678,23 @@ def fetch_requested_init_container_logs(
containers_to_log = sorted(containers_to_log, key=lambda cn: all_containers.index(cn))
for c in containers_to_log:
self._await_init_container_start(pod=pod, container_name=c)
status = self.fetch_container_logs(pod=pod, container_name=c, follow=follow_logs)
status = self.fetch_container_logs(
pod=pod,
container_name=c,
follow=follow_logs,
container_name_log_prefix_enabled=container_name_log_prefix_enabled,
log_formatter=log_formatter,
)
pod_logging_statuses.append(status)
return pod_logging_statuses

def fetch_requested_container_logs(
self, pod: V1Pod, containers: Iterable[str] | str | Literal[True], follow_logs=False
self,
pod: V1Pod,
containers: Iterable[str] | str | Literal[True],
follow_logs=False,
container_name_log_prefix_enabled: bool = True,
log_formatter: Callable[[str, str], str] | None = None,
) -> list[PodLoggingStatus]:
"""
Follow the logs of containers in the specified pod and publish it to airflow logging.
Expand All @@ -672,7 +711,13 @@ def fetch_requested_container_logs(
pod_name=pod.metadata.name,
)
for c in containers_to_log:
status = self.fetch_container_logs(pod=pod, container_name=c, follow=follow_logs)
status = self.fetch_container_logs(
pod=pod,
container_name=c,
follow=follow_logs,
container_name_log_prefix_enabled=container_name_log_prefix_enabled,
log_formatter=log_formatter,
)
pod_logging_statuses.append(status)
return pod_logging_statuses

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1733,7 +1733,13 @@ def test_get_logs_but_not_for_base_container(
pod, _ = self.run_pod(k)

# check that the base container is not included in the logs
mock_fetch_log.assert_called_once_with(pod=pod, containers=["some_init_container"], follow_logs=True)
mock_fetch_log.assert_called_once_with(
pod=pod,
containers=["some_init_container"],
follow_logs=True,
container_name_log_prefix_enabled=True,
log_formatter=None,
)
# check that KPO waits for the base container to complete before proceeding to extract XCom
mock_await_container_completion.assert_called_once_with(
pod=pod, container_name="base", polling_time=1
Expand Down Expand Up @@ -1999,7 +2005,16 @@ def test_await_container_completion_refreshes_properties_on_exception(

if get_logs:
fetch_requested_container_logs.assert_has_calls(
[mock.call(pod=pod, containers=k.container_logs, follow_logs=True)] * 3
[
mock.call(
pod=pod,
containers=k.container_logs,
follow_logs=True,
container_name_log_prefix_enabled=True,
log_formatter=None,
)
]
* 3
)
else:
mock_await_container_completion.assert_has_calls(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,8 @@ def test_get_logs_from_driver(
pod=op.pod,
containers="spark-kubernetes-driver",
follow_logs=True,
container_name_log_prefix_enabled=True,
log_formatter=None,
)

@pytest.mark.asyncio
Expand Down
Loading