Skip to content
12 changes: 10 additions & 2 deletions providers/amazon/src/airflow/providers/amazon/aws/hooks/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import asyncio
from collections.abc import AsyncGenerator, Generator
from typing import Any
from typing import Any, TypedDict

from botocore.exceptions import ClientError

Expand All @@ -35,6 +35,14 @@
NUM_CONSECUTIVE_EMPTY_RESPONSE_EXIT_THRESHOLD = 3


class CloudWatchLogEvent(TypedDict):
    """TypedDict for CloudWatch Log Event.

    Shape of a single log event as yielded by ``AwsLogsHook.get_log_events``
    (one element of the ``events`` list of the CloudWatch Logs
    ``GetLogEvents`` API response).
    """

    # Event time, in epoch milliseconds (consumers divide by 1000 before
    # passing to datetime.fromtimestamp).
    timestamp: int
    # The raw log line as stored in the stream.
    message: str
    # Presumably the time CloudWatch ingested the event, also epoch
    # milliseconds — TODO confirm against the GetLogEvents API docs.
    ingestionTime: int


class AwsLogsHook(AwsBaseHook):
"""
Interact with Amazon CloudWatch Logs.
Expand Down Expand Up @@ -67,7 +75,7 @@ def get_log_events(
start_from_head: bool | None = None,
continuation_token: ContinuationToken | None = None,
end_time: int | None = None,
) -> Generator:
) -> Generator[CloudWatchLogEvent, None, None]:
"""
Return a generator for log items in a single stream; yields all items available at the current moment.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import json
import logging
import os
from collections.abc import Generator
from datetime import date, datetime, timedelta, timezone
from functools import cached_property
from pathlib import Path
Expand All @@ -40,8 +41,15 @@
import structlog.typing

from airflow.models.taskinstance import TaskInstance
from airflow.providers.amazon.aws.hooks.logs import CloudWatchLogEvent
from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
from airflow.utils.log.file_task_handler import LogMessages, LogSourceInfo
from airflow.utils.log.file_task_handler import (
LogMessages,
LogResponse,
LogSourceInfo,
RawLogStream,
StreamingLogResponse,
)


def json_serialize_legacy(value: Any) -> str | None:
Expand Down Expand Up @@ -163,20 +171,31 @@ def upload(self, path: os.PathLike | str, ti: RuntimeTI):
self.close()
return

def read(self, relative_path: str, ti: RuntimeTI) -> LogResponse:
    """Read the remote CloudWatch log for a task run as a list of lines.

    Delegates to :meth:`stream` and materializes every streamed log line,
    appending a trailing newline to each.

    :param relative_path: the CloudWatch log stream name for the task run.
    :param ti: the runtime task instance whose log is being read.
    :return: tuple of (source-info messages, newline-terminated log lines).
    """
    messages, log_streams = self.stream(relative_path, ti)
    # stream() returns log_streams=None when fetching the log failed; guard
    # with `or []` so we still surface the error messages instead of raising
    # TypeError while iterating over None.
    str_logs: list[str] = [f"{line}\n" for group in log_streams or [] for line in group]

    return messages, str_logs

def stream(self, relative_path: str, ti: RuntimeTI) -> StreamingLogResponse:
    """Stream the remote CloudWatch log as lazily-evaluated line generators.

    :param relative_path: the CloudWatch log stream name for the task run.
    :param ti: the runtime task instance whose log is being streamed.
    :return: tuple of (source-info messages, list of log-line generators);
        the second element is ``None`` when fetching the log failed.
    """
    logs: list[RawLogStream] | None = []
    messages = [
        f"Reading remote log from Cloudwatch log_group: {self.log_group} log_stream: {relative_path}"
    ]
    try:
        # Lazily turn each CloudWatch event into one JSON-dumped string line.
        gen: RawLogStream = (
            self._parse_log_event_as_dumped_json(event)
            for event in self.get_cloudwatch_logs(relative_path, ti)
        )
        logs = [gen]
    except Exception as e:
        # NOTE(review): `gen` is a lazy generator, so this only catches errors
        # raised while creating it; errors raised during iteration surface at
        # consumption time in the caller — confirm this is intended.
        # On failure `logs` becomes None (callers must guard) and the error
        # text is surfaced through `messages`.
        logs = None
        messages.append(str(e))

    return messages, logs

def get_cloudwatch_logs(self, stream_name: str, task_instance: RuntimeTI):
def get_cloudwatch_logs(
self, stream_name: str, task_instance: RuntimeTI
) -> Generator[CloudWatchLogEvent, None, None]:
"""
Return all logs from the given log stream.

Expand All @@ -192,29 +211,22 @@ def get_cloudwatch_logs(self, stream_name: str, task_instance: RuntimeTI):
if (end_date := getattr(task_instance, "end_date", None)) is None
else datetime_to_epoch_utc_ms(end_date + timedelta(seconds=30))
)
events = self.hook.get_log_events(
return self.hook.get_log_events(
log_group=self.log_group,
log_stream_name=stream_name,
end_time=end_time,
)
return "\n".join(self._event_to_str(event) for event in events)

def _event_to_dict(self, event: dict) -> dict:
def _parse_log_event_as_dumped_json(self, event: CloudWatchLogEvent) -> str:
event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc).isoformat()
message = event["message"]
event_msg = event["message"]
try:
message = json.loads(message)
message = json.loads(event_msg)
message["timestamp"] = event_dt
return message
except Exception:
return {"timestamp": event_dt, "event": message}
message = {"timestamp": event_dt, "event": event_msg}

def _event_to_str(self, event: dict) -> str:
event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc)
# Format a datetime object to a string in Zulu time without milliseconds.
formatted_event_dt = event_dt.strftime("%Y-%m-%dT%H:%M:%SZ")
message = event["message"]
return f"[{formatted_event_dt}] {message}"
return json.dumps(message)


class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
Expand Down Expand Up @@ -291,4 +303,22 @@ def _read_remote_logs(
) -> tuple[LogSourceInfo, LogMessages]:
stream_name = self._render_filename(task_instance, try_number)
messages, logs = self.io.read(stream_name, task_instance)
return messages, logs or []

messages = [
f"Reading remote log from Cloudwatch log_group: {self.io.log_group} log_stream: {stream_name}"
]
try:
events = self.io.get_cloudwatch_logs(stream_name, task_instance)
logs = ["\n".join(self._event_to_str(event) for event in events)]
except Exception as e:
logs = []
messages.append(str(e))

return messages, logs

def _event_to_str(self, event: CloudWatchLogEvent) -> str:
event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc)
# Format a datetime object to a string in Zulu time without milliseconds.
formatted_event_dt = event_dt.strftime("%Y-%m-%dT%H:%M:%SZ")
message = event["message"]
return f"[{formatted_event_dt}] {message}"
Original file line number Diff line number Diff line change
Expand Up @@ -159,23 +159,9 @@ def test_log_message(self):
assert metadata == [
f"Reading remote log from Cloudwatch log_group: log_group_name log_stream: {stream_name}"
]
assert logs == ['[2025-03-27T21:58:01Z] {"foo": "bar", "event": "Hi", "level": "info"}']

def test_event_to_str(self):
handler = self.subject
current_time = int(time.time()) * 1000
events = [
{"timestamp": current_time - 2000, "message": "First"},
{"timestamp": current_time - 1000, "message": "Second"},
{"timestamp": current_time, "message": "Third"},
]
assert [handler._event_to_str(event) for event in events] == (
[
f"[{get_time_str(current_time - 2000)}] First",
f"[{get_time_str(current_time - 1000)}] Second",
f"[{get_time_str(current_time)}] Third",
assert logs == [
'{"foo": "bar", "event": "Hi", "level": "info", "timestamp": "2025-03-27T21:58:01.002000+00:00"}\n'
]
)


@pytest.mark.db_test
Expand Down Expand Up @@ -426,6 +412,22 @@ def test_filename_template_for_backward_compatibility(self):
filename_template=None,
)

def test_event_to_str(self):
    """Events render as '[<zulu time>] <message>' lines, in order."""
    handler = self.cloudwatch_task_handler
    now_ms = int(time.time()) * 1000
    samples = [(2000, "First"), (1000, "Second"), (0, "Third")]
    events = [{"timestamp": now_ms - delta, "message": text} for delta, text in samples]
    expected = [f"[{get_time_str(now_ms - delta)}] {text}" for delta, text in samples]
    assert [handler._event_to_str(event) for event in events] == expected


def generate_log_events(conn, log_group_name, log_stream_name, log_events):
conn.create_log_group(logGroupName=log_group_name)
Expand Down