diff --git a/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py b/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py
index 9736abbd9024a..f6540cdf41cb0 100644
--- a/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py
+++ b/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py
@@ -43,9 +43,11 @@ from airflow.utils.log.logging_mixin import LoggingMixin
 
 if TYPE_CHECKING:
+    from io import TextIOWrapper
+
     from airflow.models.taskinstance import TaskInstance
     from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
-    from airflow.utils.log.file_task_handler import LogMessages, LogSourceInfo
+    from airflow.utils.log.file_task_handler import LogResponse, RawLogStream, StreamingLogResponse
 
 _DEFAULT_SCOPESS = frozenset(
     [
@@ -149,11 +151,26 @@ def no_log_found(exc):
             exc, "resp", {}
         ).get("status") == "404"
 
-    def read(self, relative_path: str, ti: RuntimeTI) -> tuple[LogSourceInfo, LogMessages | None]:
-        messages = []
-        logs = []
+    def read(self, relative_path: str, ti: RuntimeTI) -> LogResponse:
+        messages, log_streams = self.stream(relative_path, ti)
+        if not log_streams:
+            return messages, None
+
+        logs: list[str] = []
+        try:
+            # for each log_stream, exhaust the generator into a string
+            logs = ["".join(line for line in log_stream) for log_stream in log_streams]
+        except Exception as e:
+            if not AIRFLOW_V_3_0_PLUS:
+                messages.append(f"Unable to read remote log {e}")
+
+        return messages, logs
+
+    def stream(self, relative_path: str, ti: RuntimeTI) -> StreamingLogResponse:
+        messages: list[str] = []
+        log_streams: list[RawLogStream] = []
         remote_loc = os.path.join(self.remote_base, relative_path)
-        uris = []
+        uris: list[str] = []
         bucket, prefix = _parse_gcs_url(remote_loc)
 
         blobs = list(self.client.list_blobs(bucket_or_name=bucket, prefix=prefix))
@@ -164,18 +181,29 @@ def read(self, relative_path: str, ti: RuntimeTI) -> tuple[LogSourceInfo, LogMes
             else:
                 messages.extend(["Found remote logs:", *[f"  * {x}" for x in sorted(uris)]])
         else:
-            return messages, None
+            return messages, []
 
         try:
             for key in sorted(uris):
                 blob = storage.Blob.from_string(key, self.client)
-                remote_log = blob.download_as_bytes().decode()
-                if remote_log:
-                    logs.append(remote_log)
+                stream = blob.open("r")
+                log_streams.append(self._get_log_stream(stream))
         except Exception as e:
             if not AIRFLOW_V_3_0_PLUS:
                 messages.append(f"Unable to read remote log {e}")
-        return messages, logs
+        return messages, log_streams
+
+    def _get_log_stream(self, stream: TextIOWrapper) -> RawLogStream:
+        """
+        Yield lines from the given stream.
+
+        :param stream: The opened stream to read from.
+        :yield: Lines of the log file.
+        """
+        try:
+            yield from stream
+        finally:
+            stream.close()
 
 
 class GCSTaskHandler(FileTaskHandler, LoggingMixin):
@@ -273,7 +301,7 @@ def close(self):
         # Mark closed so we don't double write if close is called twice
         self.closed = True
 
-    def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[LogSourceInfo, LogMessages]:
+    def _read_remote_logs(self, ti, try_number, metadata=None) -> LogResponse:
         # Explicitly getting log relative path is necessary as the given
         # task instance might be different than task instance passed in
         # in set_context method.
@@ -283,7 +311,7 @@ def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[LogSourceInf
 
         if logs is None:
             logs = []
-            if not AIRFLOW_V_3_0_PLUS:
+            if not AIRFLOW_V_3_0_PLUS and not messages:
                 messages.append(f"No logs found in GCS; ti={ti}")
 
         return messages, logs
diff --git a/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py b/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py
index 135d443244559..ad0f20b289013 100644
--- a/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py
+++ b/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py
@@ -17,14 +17,17 @@
 from __future__ import annotations
 
 import copy
+import io
 import logging
 import os
+from types import GeneratorType
+from typing import TYPE_CHECKING
 from unittest import mock
 from unittest.mock import MagicMock
 
 import pytest
 
-from airflow.providers.google.cloud.log.gcs_task_handler import GCSTaskHandler
+from airflow.providers.google.cloud.log.gcs_task_handler import GCSRemoteLogIO, GCSTaskHandler
 from airflow.utils.state import TaskInstanceState
 from airflow.utils.timezone import datetime
 
@@ -32,6 +35,256 @@
 from tests_common.test_utils.db import clear_db_dags, clear_db_runs
 from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
 
+if TYPE_CHECKING:
+    from pathlib import Path
+
+
+def patch_mock_client_for_list_blobs(mock_client: MagicMock, blob_names: list[str]):
+    mock_blobs = []
+    for name in blob_names:
+        mock_blob = MagicMock()
+        mock_blob.name = name
+        mock_blobs.append(mock_blob)
+    mock_client.return_value.list_blobs.return_value = mock_blobs
+
+
+@pytest.mark.db_test
+@mock.patch(
+    "airflow.providers.google.cloud.log.gcs_task_handler.get_credentials_and_project_id",
+    return_value=("TEST_CREDENTIALS", "TEST_PROJECT_ID"),
+)
+class TestGCSRemoteLogIO:
+    @pytest.fixture(autouse=True)
+    def setup_tests(self, create_task_instance, session):
+        # setup remote IO
+        self.base_log_folder = "local/airflow/logs"
+        self.gcs_log_folder = "gs://bucket/airflow/logs"
+        # setup TaskInstance
+        self.ti = ti = create_task_instance(task_id="task_1")
+        ti.try_number = 1
+        ti.raw = False
+        session.add(ti)
+        session.commit()
+        yield
+        clear_db_runs()
+        clear_db_dags()
+
+    @pytest.mark.parametrize(
+        "is_absolute",
+        [pytest.param(True, id="absolute"), pytest.param(False, id="relative")],
+    )
+    @pytest.mark.parametrize(
+        "file_exists",
+        [pytest.param(True, id="file-exists"), pytest.param(False, id="file-not-exists")],
+    )
+    @pytest.mark.parametrize(
+        "delete_local_copy",
+        [pytest.param(True, id="delete-local"), pytest.param(False, id="keep-local")],
+    )
+    @pytest.mark.parametrize(
+        "mock_write_method_result",
+        [pytest.param(True, id="write-success"), pytest.param(False, id="write-fail")],
+    )
+    @mock.patch("google.cloud.storage.Client")
+    @mock.patch("google.cloud.storage.Blob")
+    @mock.patch("shutil.rmtree")
+    def test_upload(
+        self,
+        mock_rmtree,
+        mock_blob,
+        mock_client,
+        mock_creds,
+        mock_write_method_result: bool,
+        delete_local_copy: bool,
+        file_exists: bool,
+        is_absolute: bool,
+        tmp_path: Path,
+    ):
+        # setup
+        gcs_remote_log_io = GCSRemoteLogIO(
+            remote_base=self.gcs_log_folder,
+            base_log_folder=tmp_path.as_posix(),
+            delete_local_copy=delete_local_copy,
+        )
+        if file_exists:
+            file_path = tmp_path / "existing.log" if is_absolute else "existing.log"
+            with open(tmp_path / "existing.log", "w") as f:
+                f.write("log content")
+        else:
+            file_path = tmp_path / "non_existing.log"
+
+        # action
+        with mock.patch.object(
+            gcs_remote_log_io,
"write", + return_value=mock_write_method_result, + ) as mock_write_method: + gcs_remote_log_io.upload(file_path, self.ti) + + # verify + if file_exists: + mock_write_method.assert_called_once() + if delete_local_copy and mock_write_method_result: + mock_rmtree.assert_called_once_with(tmp_path.as_posix()) + else: + mock_rmtree.assert_not_called() + else: + mock_write_method.assert_not_called() + mock_rmtree.assert_not_called() + + @pytest.mark.parametrize( + "upload_success", + [pytest.param(True, id="upload-success"), pytest.param(False, id="upload-fail")], + ) + @pytest.mark.parametrize( + "old_log_exists", + [ + pytest.param(True, id="old-log-exists"), + pytest.param(False, id="old-log-not-exists"), + ], + ) + @pytest.mark.parametrize( + "old_log_read_error", + [ + pytest.param(None, id="old-log-read-success"), + pytest.param(Exception("Read error"), id="old-log-read-error"), + ], + ) + @mock.patch("google.cloud.storage.Client") + @mock.patch("google.cloud.storage.Blob") + def test_write( + self, + mock_blob, + mock_client, + mock_creds, + upload_success: bool, + old_log_exists: bool, + old_log_read_error: Exception | None, + ): + # setup + remote_log_location = f"{self.gcs_log_folder}/task_1/attempt_1.log" + new_log_content = "NEW LOG CONTENT" + old_log_content = "OLD LOG CONTENT" + + # mock download_as_bytes for reading old log + if old_log_read_error: + mock_blob.from_string.return_value.download_as_bytes.side_effect = old_log_read_error + elif old_log_exists: + mock_blob.from_string.return_value.download_as_bytes.return_value = old_log_content.encode() + else: + # Simulate 404 error for no log found + not_found_error = Exception("No such object: bucket/airflow/logs/task_1/attempt_1.log") + mock_blob.from_string.return_value.download_as_bytes.side_effect = not_found_error + + # mock upload_from_string + if not upload_success: + mock_blob.from_string.return_value.upload_from_string.side_effect = Exception("Upload failed") + + gcs_remote_log_io = GCSRemoteLogIO( + remote_base=self.gcs_log_folder, + base_log_folder=self.base_log_folder, + delete_local_copy=False, + ) + + # action + result = gcs_remote_log_io.write(new_log_content, remote_log_location) + + # verify + assert result == upload_success + + # verify the content that was uploaded + if upload_success: + call_args = mock_blob.from_string.return_value.upload_from_string.call_args + if call_args: + uploaded_content = call_args[0][0] + if old_log_exists and not old_log_read_error: + assert uploaded_content == f"{old_log_content}\n{new_log_content}" + else: + assert uploaded_content == new_log_content + + @pytest.mark.parametrize( + "is_stream_method", + [pytest.param(True, id="is-stream"), pytest.param(False, id="not-stream")], + ) + @pytest.mark.parametrize( + "blob_names", + [ + pytest.param( + ["airflow/logs/task_1/attempt_1.log", "airflow/logs/task_1/attempt_2.log"], id="blobs-exists" + ), + pytest.param([], id="blobs-not-exists"), + ], + ) + @pytest.mark.parametrize( + "read_success", + [pytest.param(True, id="read-success"), pytest.param(False, id="read-fail")], + ) + @mock.patch("google.cloud.storage.Client") + @mock.patch("google.cloud.storage.Blob") + def test_stream_and_read_methods( + self, + mock_blob, + mock_client, + mock_creds, + read_success: bool, + blob_names: list[str], + is_stream_method: bool, + ): + # setup + patch_mock_client_for_list_blobs(mock_client, blob_names) + if read_success: + mock_blob.from_string.return_value.open.side_effect = lambda mode: io.TextIOWrapper( + io.BytesIO(b"LOG\nCONTENT"), 
+                encoding="utf-8"
+            )
+        else:
+            mock_blob.from_string.return_value.open.side_effect = Exception("Read failed")
+
+        gcs_remote_log_io = GCSRemoteLogIO(
+            remote_base=self.gcs_log_folder,
+            base_log_folder=self.base_log_folder,
+            delete_local_copy=False,
+        )
+        # action
+        if is_stream_method:
+            messages, log_streams = gcs_remote_log_io.stream("airflow/logs/task_1", self.ti)
+            logs = log_streams  # type: ignore[assignment]
+        else:
+            messages, logs = gcs_remote_log_io.read("airflow/logs/task_1", self.ti)  # type: ignore[assignment]
+
+        # early return for no blobs
+        if not blob_names:
+            assert messages == []
+            if is_stream_method:
+                assert logs == []
+            else:
+                assert logs is None
+            return
+
+        # verify messages
+        expected_uris = [
+            f"{self.gcs_log_folder}/task_1/attempt_1.log",
+            f"{self.gcs_log_folder}/task_1/attempt_2.log",
+        ]
+        if AIRFLOW_V_3_0_PLUS:
+            expected_messages = expected_uris
+        else:
+            expected_messages = ["Found remote logs:", *[f"  * {x}" for x in sorted(expected_uris)]]
+        if not read_success and not AIRFLOW_V_3_0_PLUS:
+            expected_messages = expected_messages + ["Unable to read remote log Read failed"]
+        assert messages == expected_messages
+
+        # verify logs
+        expected_logs = ["LOG\nCONTENT", "LOG\nCONTENT"]
+        if is_stream_method:
+            for log_stream, expected_log in zip(logs, expected_logs):
+                assert isinstance(log_stream, GeneratorType)
+                assert "".join(log_stream) == expected_log
+        else:
+            if read_success:
+                assert logs == expected_logs
+            else:
+                assert logs is None
+
 
 @pytest.mark.db_test
 class TestGCSTaskHandler:
@@ -102,16 +355,17 @@ def test_client_conn_id_behavior(self, mock_get_cred, mock_client, mock_hook, co
     def test_should_read_logs_from_remote(
         self, mock_blob, mock_client, mock_creds, session, sdk_connection_not_found
     ):
-        mock_obj = MagicMock()
-        mock_obj.name = "remote/log/location/1.log"
-        mock_client.return_value.list_blobs.return_value = [mock_obj]
-        mock_blob.from_string.return_value.download_as_bytes.return_value = b"CONTENT"
+        blob_name = "remote/log/location/1.log"
+        patch_mock_client_for_list_blobs(mock_client, [blob_name])
+        mock_blob.from_string.return_value.open.return_value = io.TextIOWrapper(
+            io.BytesIO(b"CONTENT"), encoding="utf-8"
+        )
         ti = copy.copy(self.ti)
         ti.state = TaskInstanceState.SUCCESS
         session.add(ti)
         session.commit()
         logs, metadata = self.gcs_task_handler._read(ti, self.ti.try_number)
-        expected_gs_uri = f"gs://bucket/{mock_obj.name}"
+        expected_gs_uri = f"gs://bucket/{blob_name}"
 
         mock_blob.from_string.assert_called_once_with(expected_gs_uri, mock_client.return_value)
 
@@ -134,16 +388,15 @@ def test_should_read_logs_from_remote(
     @mock.patch("google.cloud.storage.Client")
     @mock.patch("google.cloud.storage.Blob")
     def test_should_read_from_local_on_logs_read_error(self, mock_blob, mock_client, mock_creds):
-        mock_obj = MagicMock()
-        mock_obj.name = "remote/log/location/1.log"
-        mock_client.return_value.list_blobs.return_value = [mock_obj]
-        mock_blob.from_string.return_value.download_as_bytes.side_effect = Exception("Failed to connect")
+        blob_name = "remote/log/location/1.log"
+        patch_mock_client_for_list_blobs(mock_client, [blob_name])
+        mock_blob.from_string.return_value.open.side_effect = Exception("Failed to connect")
 
         self.gcs_task_handler.set_context(self.ti)
         ti = copy.copy(self.ti)
         ti.state = TaskInstanceState.SUCCESS
         log, metadata = self.gcs_task_handler._read(ti, self.ti.try_number)
-        expected_gs_uri = f"gs://bucket/{mock_obj.name}"
+        expected_gs_uri = f"gs://bucket/{blob_name}"
 
         if AIRFLOW_V_3_0_PLUS:
             log = list(log)
@@ -297,7 +550,9 @@ def test_close_with_delete_local_copy_conf(
         delete_local_copy,
         expected_existence_of_local_copy,
     ):
-        mock_blob.from_string.return_value.download_as_bytes.return_value = b"CONTENT"
+        mock_blob.from_string.return_value.open.return_value = io.TextIOWrapper(
+            io.BytesIO(b"CONTENT"), encoding="utf-8"
+        )
         with conf_vars({("logging", "delete_local_logs"): str(delete_local_copy)}):
             handler = GCSTaskHandler(
                 base_log_folder=local_log_location,
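
The core change in this patch swaps eager `blob.download_as_bytes()` reads for lazy, line-oriented generators built on `blob.open("r")`. For anyone reviewing the generator pattern, here is a minimal, self-contained sketch of the same try/finally idiom, using `io.StringIO` as a stand-in for the `TextIOWrapper` that `blob.open("r")` returns (no GCS client involved; the names below are illustrative only, not part of the patch):

```python
import io
from collections.abc import Generator


def get_log_stream(stream: io.TextIOBase) -> Generator[str, None, None]:
    # Yield lines lazily; the handle is closed only once the consumer
    # exhausts (or abandons) the generator, mirroring _get_log_stream above.
    try:
        yield from stream
    finally:
        stream.close()


fake_blob = io.StringIO("line 1\nline 2\n")
log_stream = get_log_stream(fake_blob)

# Nothing has been read yet -- the stream stays open until consumed.
assert not fake_blob.closed

# read() in this patch exhausts each stream the same way: "".join(stream).
assert "".join(log_stream) == "line 1\nline 2\n"
assert fake_blob.closed  # the finally block released the handle
```

One consequence worth noting in review: because `stream()` returns unconsumed generators, errors from `blob.open("r")` are caught inside `stream()` itself, while an error raised mid-iteration propagates to whoever consumes the generator; the `read()` wrapper preserves the old buffered contract by joining each generator into a single string before returning.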