@@ -43,9 +43,11 @@
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 if TYPE_CHECKING:
+    from io import TextIOWrapper
+
     from airflow.models.taskinstance import TaskInstance
     from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
-    from airflow.utils.log.file_task_handler import LogMessages, LogSourceInfo
+    from airflow.utils.log.file_task_handler import LogResponse, RawLogStream, StreamingLogResponse
 
 _DEFAULT_SCOPESS = frozenset(
     [
@@ -149,11 +151,26 @@ def no_log_found(exc):
                 exc, "resp", {}
             ).get("status") == "404"
 
-    def read(self, relative_path: str, ti: RuntimeTI) -> tuple[LogSourceInfo, LogMessages | None]:
-        messages = []
-        logs = []
+    def read(self, relative_path: str, ti: RuntimeTI) -> LogResponse:
+        messages, log_streams = self.stream(relative_path, ti)
+        if not log_streams:
+            return messages, None
+
+        logs: list[str] = []
+        try:
+            # for each log_stream, exhaust the generator into a string
+            logs = ["".join(line for line in log_stream) for log_stream in log_streams]
+        except Exception as e:
+            if not AIRFLOW_V_3_0_PLUS:
+                messages.append(f"Unable to read remote log {e}")
+
+        return messages, logs
+
+    def stream(self, relative_path: str, ti: RuntimeTI) -> StreamingLogResponse:
+        messages: list[str] = []
+        log_streams: list[RawLogStream] = []
         remote_loc = os.path.join(self.remote_base, relative_path)
-        uris = []
+        uris: list[str] = []
         bucket, prefix = _parse_gcs_url(remote_loc)
         blobs = list(self.client.list_blobs(bucket_or_name=bucket, prefix=prefix))
 
@@ -164,18 +181,29 @@ def read(self, relative_path: str, ti: RuntimeTI) -> tuple[LogSourceInfo, LogMes
             else:
                 messages.extend(["Found remote logs:", *[f" * {x}" for x in sorted(uris)]])
         else:
-            return messages, None
+            return messages, []
 
         try:
             for key in sorted(uris):
                 blob = storage.Blob.from_string(key, self.client)
-                remote_log = blob.download_as_bytes().decode()
-                if remote_log:
-                    logs.append(remote_log)
+                stream = blob.open("r")
+                log_streams.append(self._get_log_stream(stream))
         except Exception as e:
             if not AIRFLOW_V_3_0_PLUS:
                 messages.append(f"Unable to read remote log {e}")
-        return messages, logs
+        return messages, log_streams
+
+    def _get_log_stream(self, stream: TextIOWrapper) -> RawLogStream:
+        """
+        Yield lines from the given stream.
+
+        :param stream: The opened stream to read from.
+        :yield: Lines of the log file.
+        """
+        try:
+            yield from stream
+        finally:
+            stream.close()
 
 
 class GCSTaskHandler(FileTaskHandler, LoggingMixin):
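Note: the streaming read works by wrapping each blob's file handle in a line generator and only joining the lines when the full text is requested. A minimal standalone sketch of that pattern, using io.StringIO as a stand-in for the handle that blob.open("r") would return (no GCS client involved; names here are illustrative only, not the provider's code):

from collections.abc import Generator
from io import StringIO


def get_log_stream(stream) -> Generator[str, None, None]:
    # Yield lines lazily; close the handle once the consumer is done with it.
    try:
        yield from stream
    finally:
        stream.close()


# stream() would return one such generator per remote log file.
log_streams = [get_log_stream(StringIO("line 1\nline 2\n"))]
# read() exhausts each generator into a single string, as in the diff above.
logs = ["".join(line for line in log_stream) for log_stream in log_streams]
print(logs)  # ['line 1\nline 2\n']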
@@ -273,7 +301,7 @@ def close(self):
         # Mark closed so we don't double write if close is called twice
         self.closed = True
 
-    def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[LogSourceInfo, LogMessages]:
+    def _read_remote_logs(self, ti, try_number, metadata=None) -> LogResponse:
         # Explicitly getting log relative path is necessary as the given
         # task instance might be different than task instance passed in
         # in set_context method.
@@ -283,7 +311,7 @@ def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[LogSourceInf
 
         if logs is None:
             logs = []
-            if not AIRFLOW_V_3_0_PLUS:
+            if not AIRFLOW_V_3_0_PLUS and not messages:
                 messages.append(f"No logs found in GCS; ti={ti}")
 
         return messages, logs
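For orientation, the return shapes implied by the new annotations can be read off the diff: read() yields a list of informational messages plus the full text of each log file (or None), while stream() yields the messages plus one lazy line generator per file. A rough sketch of those shapes, assumed here for illustration only (the actual aliases are imported from airflow.utils.log.file_task_handler and may be defined differently):

from collections.abc import Generator
from typing import Optional

LogSourceInfo = list[str]                                        # e.g. "Found remote logs:", " * gs://..."
RawLogStream = Generator[str, None, None]                        # lazily yields log lines
LogResponse = tuple[LogSourceInfo, Optional[list[str]]]          # shape returned by read()
StreamingLogResponse = tuple[LogSourceInfo, list[RawLogStream]]  # shape returned by stream()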