From 7f52f67a99fc7a0e06b0fa4d989600dfd83bc6ff Mon Sep 17 00:00:00 2001
From: LIU ZHE YOU
Date: Tue, 23 Dec 2025 13:01:58 +0800
Subject: [PATCH 01/10] Add stream method for GCSRemoteLogIO

---
 .../google/cloud/log/gcs_task_handler.py      | 41 ++++++++++++++-----
 1 file changed, 31 insertions(+), 10 deletions(-)

diff --git a/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py b/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py
index 9736abbd9024a..c24bdf4faacd0 100644
--- a/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py
+++ b/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py
@@ -39,13 +39,13 @@
 from airflow.providers.google.common.consts import CLIENT_INFO
 from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
 from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
-from airflow.utils.log.file_task_handler import FileTaskHandler
+from airflow.utils.log.file_task_handler import FileTaskHandler, RawLogStream
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 if TYPE_CHECKING:
     from airflow.models.taskinstance import TaskInstance
     from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
-    from airflow.utils.log.file_task_handler import LogMessages, LogSourceInfo
+    from airflow.utils.log.file_task_handler import LogResponse, RawLogStream, StreamingLogResponse
 
 _DEFAULT_SCOPESS = frozenset(
     [
@@ -149,11 +149,21 @@ def no_log_found(exc):
             exc, "resp", {}
         ).get("status") == "404"
 
-    def read(self, relative_path: str, ti: RuntimeTI) -> tuple[LogSourceInfo, LogMessages | None]:
-        messages = []
+    def read(self, relative_path: str, ti: RuntimeTI) -> LogResponse:
+        messages, log_streams = self.stream(relative_path, ti)
+
+        # for each log_stream, exhaust the generator into a string
         logs = []
+        for log_stream in log_streams:
+            log_content = "".join(line for line in log_stream)
+            logs.append(log_content)
+        return messages, logs
+
+    def stream(self, relative_path: str, ti: RuntimeTI) -> StreamingLogResponse:
+        messages: list[str] = []
+        log_streams: RawLogStream = []
         remote_loc = os.path.join(self.remote_base, relative_path)
-        uris = []
+        uris: list[str] = []
         bucket, prefix = _parse_gcs_url(remote_loc)
 
         blobs = list(self.client.list_blobs(bucket_or_name=bucket, prefix=prefix))
@@ -169,13 +179,24 @@ def read(self, relative_path: str, ti: RuntimeTI) -> tuple[LogSourceInfo, LogMes
         try:
             for key in sorted(uris):
                 blob = storage.Blob.from_string(key, self.client)
-                remote_log = blob.download_as_bytes().decode()
-                if remote_log:
-                    logs.append(remote_log)
+                log_streams.append(self._get_log_stream(blob))
         except Exception as e:
             if not AIRFLOW_V_3_0_PLUS:
                 messages.append(f"Unable to read remote log {e}")
-        return messages, logs
+        return messages, log_streams
+
+    def _get_log_stream(self, blob: storage.Blob) -> RawLogStream:
+        """
+        Yield lines from the given GCS blob.
+
+        :param blob: The GCS blob to read from.
+        :yield: Lines of the log file.
+        """
+        stream = blob.open("r")
+        try:
+            yield from stream
+        finally:
+            stream.close()
 
 
 class GCSTaskHandler(FileTaskHandler, LoggingMixin):
@@ -273,7 +294,7 @@ def close(self):
         # Mark closed so we don't double write if close is called twice
         self.closed = True
 
-    def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[LogSourceInfo, LogMessages]:
+    def _read_remote_logs(self, ti, try_number, metadata=None) -> LogResponse:
         # Explicitly getting log relative path is necessary as the given
         # task instance might be different than task instance passed in
        # in set_context method.
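The core move in this first patch: instead of pulling a whole blob into memory with download_as_bytes().decode(), _get_log_stream wraps blob.open("r") in a generator that yields one line at a time and closes the handle when the consumer is done. A minimal, self-contained sketch of the same pattern — open_stream is a hypothetical stand-in for storage.Blob.open, here backed by io.StringIO:

    import io
    from collections.abc import Generator

    def line_stream(open_stream) -> Generator[str, None, None]:
        # Nothing is opened or read until the first next() on the generator.
        stream = open_stream()
        try:
            yield from stream
        finally:
            # Runs once the generator is exhausted or closed early.
            stream.close()

    gen = line_stream(lambda: io.StringIO("line 1\nline 2\n"))
    assert "".join(gen) == "line 1\nline 2\n"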
+ """ + stream = blob.open("r") + try: + yield from stream + finally: + stream.close() class GCSTaskHandler(FileTaskHandler, LoggingMixin): @@ -273,7 +294,7 @@ def close(self): # Mark closed so we don't double write if close is called twice self.closed = True - def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[LogSourceInfo, LogMessages]: + def _read_remote_logs(self, ti, try_number, metadata=None) -> LogResponse: # Explicitly getting log relative path is necessary as the given # task instance might be different than task instance passed in # in set_context method. From 49aeddede9aa04d8def809a79621b28796440dd8 Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Tue, 23 Dec 2025 14:25:32 +0800 Subject: [PATCH 02/10] Fix TestGCSTaskHandler, add error handling for read --- .../google/cloud/log/gcs_task_handler.py | 15 ++++++---- .../google/cloud/log/test_gcs_task_handler.py | 30 ++++++++++++++++--- 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py b/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py index c24bdf4faacd0..20909c730988a 100644 --- a/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py +++ b/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py @@ -151,12 +151,17 @@ def no_log_found(exc): def read(self, relative_path: str, ti: RuntimeTI) -> LogResponse: messages, log_streams = self.stream(relative_path, ti) - - # for each log_stream, exhaust the generator into a string logs = [] - for log_stream in log_streams: - log_content = "".join(line for line in log_stream) - logs.append(log_content) + + try: + # for each log_stream, exhaust the generator into a string + for log_stream in log_streams: + log_content = "".join(line for line in log_stream) + logs.append(log_content) + except Exception as e: + if not AIRFLOW_V_3_0_PLUS: + messages.append(f"Unable to read remote log {e}") + return messages, logs def stream(self, relative_path: str, ti: RuntimeTI) -> StreamingLogResponse: diff --git a/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py b/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py index 135d443244559..dc6f183aa90ea 100644 --- a/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py +++ b/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py @@ -17,6 +17,7 @@ from __future__ import annotations import copy +import io import logging import os from unittest import mock @@ -24,7 +25,8 @@ import pytest -from airflow.providers.google.cloud.log.gcs_task_handler import GCSTaskHandler +from airflow.providers.google.cloud.log.gcs_task_handler import GCSRemoteLogIO, GCSTaskHandler +from airflow.sdk import BaseOperator from airflow.utils.state import TaskInstanceState from airflow.utils.timezone import datetime @@ -33,6 +35,22 @@ from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="This path only works on Airflow 3") +class TestGCSRemoteLogIO: + @pytest.fixture(autouse=True) + def setup_tests(self, create_runtime_ti): + # setup remote IO + self.base_log_folder = "local/airflow/logs" + self.gcs_log_folder = "gs://bucket/remote/log/location" + self.gcs_remote_log_io = GCSRemoteLogIO( + remote_base=self.gcs_log_folder, + base_log_folder=self.base_log_folder, + delete_local_copy=True, + ) + # setup task instance + self.ti = create_runtime_ti(BaseOperator(task_id="task_1")) + + 
From ea03f62aa254ec6cdbf8127e4365f89e256b81bc Mon Sep 17 00:00:00 2001
From: LIU ZHE YOU
Date: Tue, 23 Dec 2025 15:44:10 +0800
Subject: [PATCH 03/10] Add test_upload

---
 .../google/cloud/log/test_gcs_task_handler.py | 84 +++++++++++++++++--
 1 file changed, 78 insertions(+), 6 deletions(-)

diff --git a/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py b/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py
index dc6f183aa90ea..6389b8a70e9c3 100644
--- a/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py
+++ b/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py
@@ -20,6 +20,7 @@
 import io
 import logging
 import os
+from typing import TYPE_CHECKING
 from unittest import mock
 from unittest.mock import MagicMock
 
@@ -34,6 +35,9 @@
 from tests_common.test_utils.db import clear_db_dags, clear_db_runs
 from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
 
+if TYPE_CHECKING:
+    from pathlib import Path
+
 
 @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="This path only works on Airflow 3")
 class TestGCSRemoteLogIO:
@@ -41,14 +45,82 @@ class TestGCSRemoteLogIO:
     @pytest.fixture(autouse=True)
     def setup_tests(self, create_runtime_ti):
         # setup remote IO
         self.base_log_folder = "local/airflow/logs"
-        self.gcs_log_folder = "gs://bucket/remote/log/location"
-        self.gcs_remote_log_io = GCSRemoteLogIO(
+        self.gcs_log_folder = "bucket/airflow/logs"
+        self.ti = create_runtime_ti(BaseOperator(task_id="task_1"))
+
+    @pytest.mark.parametrize(
+        "is_absolute",
+        [pytest.param(True, id="absolute"), pytest.param(False, id="relative")],
+    )
+    @pytest.mark.parametrize(
+        "file_exists",
+        [pytest.param(True, id="file-exists"), pytest.param(False, id="file-not-exists")],
+    )
+    @pytest.mark.parametrize(
+        "delete_local_copy",
+        [pytest.param(True, id="delete-local"), pytest.param(False, id="keep-local")],
+    )
+    @pytest.mark.parametrize(
+        "mock_write_method_result",
+        [pytest.param(True, id="write-success"), pytest.param(False, id="write-fail")],
+    )
+    @mock.patch("google.cloud.storage.Client")
+    @mock.patch("google.cloud.storage.Blob")
+    @mock.patch("shutil.rmtree")
+    def test_upload(
+        self,
+        mock_rmtree,
+        mock_blob,
+        mock_client,
+        tmp_path: Path,
+        is_absolute: bool,
+        file_exists: bool,
+        delete_local_copy: bool,
+        mock_write_method_result: bool,
+    ):
+        # setup
+        gcs_remote_log_io = GCSRemoteLogIO(
             remote_base=self.gcs_log_folder,
-            base_log_folder=self.base_log_folder,
-            delete_local_copy=True,
+            base_log_folder=tmp_path.as_posix(),
+            delete_local_copy=delete_local_copy,
         )
-        # setup task instance
-        self.ti = create_runtime_ti(BaseOperator(task_id="task_1"))
+        if file_exists:
+            file_path = tmp_path / "existing.log" if is_absolute else "existing.log"
+            with open(tmp_path / "existing.log", "w") as f:
+                f.write("log content")
+        else:
+            file_path = tmp_path / "non_existing.log"
+
+        # action
+        with mock.patch.object(
+            gcs_remote_log_io,
+            "write",
+            return_value=mock_write_method_result,
+        ) as mock_write_method:
+            gcs_remote_log_io.upload(file_path, self.ti)
+
+        # verify
+        if file_exists:
+            mock_write_method.assert_called_once()
+            if delete_local_copy and mock_write_method_result:
+                mock_rmtree.assert_called_once_with(tmp_path.as_posix())
+            else:
+                mock_rmtree.assert_not_called()
+        else:
+            mock_write_method.assert_not_called()
+            mock_rmtree.assert_not_called()
+
+    def test_read_existing(self):
+        pass
+
+    def test_read_non_existing(self):
+        pass
+
+    def test_stream_existing(self):
+        pass
+
+    def test_stream_non_existing(self):
+        pass
 
 
 @pytest.mark.db_test
From ec41cfa01837f06b732abcf81c36bffeb9cbc9b4 Mon Sep 17 00:00:00 2001
From: LIU ZHE YOU
Date: Tue, 23 Dec 2025 17:11:01 +0800
Subject: [PATCH 04/10] Open stream outside of _get_log_stream, early return
 for read if logs is None

---
 .../google/cloud/log/gcs_task_handler.py      | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py b/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py
index 20909c730988a..8bd575aacffdf 100644
--- a/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py
+++ b/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py
@@ -43,6 +43,8 @@
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 if TYPE_CHECKING:
+    from io import IOBase
+
     from airflow.models.taskinstance import TaskInstance
     from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
     from airflow.utils.log.file_task_handler import LogResponse, RawLogStream, StreamingLogResponse
@@ -151,8 +153,10 @@ def no_log_found(exc):
 
     def read(self, relative_path: str, ti: RuntimeTI) -> LogResponse:
         messages, log_streams = self.stream(relative_path, ti)
-        logs = []
+        if log_streams is None:
+            return messages, None
 
+        logs: list[str] = []
         try:
             # for each log_stream, exhaust the generator into a string
             for log_stream in log_streams:
@@ -184,20 +188,20 @@ def stream(self, relative_path: str, ti: RuntimeTI) -> StreamingLogResponse:
         try:
             for key in sorted(uris):
                 blob = storage.Blob.from_string(key, self.client)
-                log_streams.append(self._get_log_stream(blob))
+                stream = blob.open("r")
+                log_streams.append(self._get_log_stream(stream))
         except Exception as e:
             if not AIRFLOW_V_3_0_PLUS:
                 messages.append(f"Unable to read remote log {e}")
         return messages, log_streams
 
-    def _get_log_stream(self, blob: storage.Blob) -> RawLogStream:
+    def _get_log_stream(self, stream: IOBase) -> RawLogStream:
         """
-        Yield lines from the given GCS blob.
+        Yield lines from the given stream.
 
-        :param blob: The GCS blob to read from.
+        :param stream: The opened stream to read from.
         :yield: Lines of the log file.
         """
-        stream = blob.open("r")
         try:
             yield from stream
         finally:
             stream.close()
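The rationale for moving blob.open("r") out of _get_log_stream: a generator body does not run until the first next(), so an open() failure inside the generator would escape the try/except in stream() and only surface later, wherever the stream is first consumed. Opening eagerly keeps the error inside the handler's except block. A tiny demonstration of that laziness:

    def lazy():
        raise RuntimeError("raised on first next(), not at call time")
        yield  # the yield makes this function a generator

    gen = lazy()  # no exception yet: the body has not started executing
    try:
        next(gen)  # only now does the body run and raise
    except RuntimeError as e:
        print(f"caught during iteration: {e}")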
- :param blob: The GCS blob to read from. + :param stream: The opened stream to read from. :yield: Lines of the log file. """ - stream = blob.open("r") try: yield from stream finally: From 7ffa3a6b9ab4eb5c34d2d18145713cfc7f0bf3dc Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Tue, 23 Dec 2025 17:16:59 +0800 Subject: [PATCH 05/10] Add test_write and test_stream_and_read_methods --- .../google/cloud/log/test_gcs_task_handler.py | 171 ++++++++++++++++-- 1 file changed, 154 insertions(+), 17 deletions(-) diff --git a/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py b/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py index 6389b8a70e9c3..5a2565c48a1b3 100644 --- a/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py +++ b/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py @@ -20,6 +20,7 @@ import io import logging import os +from types import GeneratorType from typing import TYPE_CHECKING from unittest import mock from unittest.mock import MagicMock @@ -39,13 +40,22 @@ from pathlib import Path +def patch_mock_client_for_list_blobs(mock_client: MagicMock, blob_names: list[str]): + mock_blobs = [] + for name in blob_names: + mock_blob = MagicMock() + mock_blob.name = name + mock_blobs.append(mock_blob) + mock_client.return_value.list_blobs.return_value = mock_blobs + + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="This path only works on Airflow 3") class TestGCSRemoteLogIO: @pytest.fixture(autouse=True) def setup_tests(self, create_runtime_ti): # setup remote IO self.base_log_folder = "local/airflow/logs" - self.gcs_log_folder = "bucket/airflow/logs" + self.gcs_log_folder = "gs://bucket/airflow/logs" self.ti = create_runtime_ti(BaseOperator(task_id="task_1")) @pytest.mark.parametrize( @@ -110,17 +120,146 @@ def test_upload( mock_write_method.assert_not_called() mock_rmtree.assert_not_called() - def test_read_existing(self): - pass + @pytest.mark.parametrize( + "old_log_exists", + [pytest.param(True, id="old-log-exists"), pytest.param(False, id="old-log-not-exists")], + ) + @pytest.mark.parametrize( + "upload_success", + [pytest.param(True, id="upload-success"), pytest.param(False, id="upload-fail")], + ) + @pytest.mark.parametrize( + "old_log_read_error", + [pytest.param(None, id="no-read-error"), pytest.param(Exception("Read error"), id="read-error")], + ) + @mock.patch("google.cloud.storage.Client") + @mock.patch("google.cloud.storage.Blob") + def test_write( + self, + mock_blob, + mock_client, + old_log_exists: bool, + upload_success: bool, + old_log_read_error: Exception | None, + ): + # setup + remote_log_location = f"{self.gcs_log_folder}/task_1/attempt_1.log" + new_log_content = "NEW LOG CONTENT" + old_log_content = "OLD LOG CONTENT" + + # mock download_as_bytes for reading old log + if old_log_read_error: + mock_blob.from_string.return_value.download_as_bytes.side_effect = old_log_read_error + elif old_log_exists: + mock_blob.from_string.return_value.download_as_bytes.return_value = old_log_content.encode() + else: + # Simulate 404 error for no log found + not_found_error = Exception("No such object: bucket/airflow/logs/task_1/attempt_1.log") + mock_blob.from_string.return_value.download_as_bytes.side_effect = not_found_error + + # mock upload_from_string + if not upload_success: + mock_blob.from_string.return_value.upload_from_string.side_effect = Exception("Upload failed") + + gcs_remote_log_io = GCSRemoteLogIO( + remote_base=self.gcs_log_folder, + base_log_folder=self.base_log_folder, + 
From 22c2d0c68cb012ca56b8fbf0e28e39185df401d3 Mon Sep 17 00:00:00 2001
From: LIU ZHE YOU
Date: Tue, 23 Dec 2025 17:23:12 +0800
Subject: [PATCH 06/10] Fix mistaken import of RawLogStream

---
 .../src/airflow/providers/google/cloud/log/gcs_task_handler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py b/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py
index 8bd575aacffdf..0a3dfe32c4bdb 100644
--- a/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py
+++ b/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py
@@ -39,7 +39,7 @@
 from airflow.providers.google.common.consts import CLIENT_INFO
 from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
 from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
-from airflow.utils.log.file_task_handler import FileTaskHandler, RawLogStream
+from airflow.utils.log.file_task_handler import FileTaskHandler
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 if TYPE_CHECKING:
io.BytesIO(b"CONTENT"), encoding="utf-8" ) @@ -203,7 +341,7 @@ def test_should_read_logs_from_remote( session.add(ti) session.commit() logs, metadata = self.gcs_task_handler._read(ti, self.ti.try_number) - expected_gs_uri = f"gs://bucket/{mock_obj.name}" + expected_gs_uri = f"gs://bucket/{blob_name}" mock_blob.from_string.assert_called_once_with(expected_gs_uri, mock_client.return_value) @@ -226,16 +364,15 @@ def test_should_read_logs_from_remote( @mock.patch("google.cloud.storage.Client") @mock.patch("google.cloud.storage.Blob") def test_should_read_from_local_on_logs_read_error(self, mock_blob, mock_client, mock_creds): - mock_obj = MagicMock() - mock_obj.name = "remote/log/location/1.log" - mock_client.return_value.list_blobs.return_value = [mock_obj] + blob_name = "remote/log/location/1.log" + patch_mock_client_for_list_blobs(mock_client, [blob_name]) mock_blob.from_string.return_value.open.side_effect = Exception("Failed to connect") self.gcs_task_handler.set_context(self.ti) ti = copy.copy(self.ti) ti.state = TaskInstanceState.SUCCESS log, metadata = self.gcs_task_handler._read(ti, self.ti.try_number) - expected_gs_uri = f"gs://bucket/{mock_obj.name}" + expected_gs_uri = f"gs://bucket/{blob_name}" if AIRFLOW_V_3_0_PLUS: log = list(log) From 22c2d0c68cb012ca56b8fbf0e28e39185df401d3 Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Tue, 23 Dec 2025 17:23:12 +0800 Subject: [PATCH 06/10] Fix mistook import of RawLogStream --- .../src/airflow/providers/google/cloud/log/gcs_task_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py b/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py index 8bd575aacffdf..0a3dfe32c4bdb 100644 --- a/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py +++ b/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py @@ -39,7 +39,7 @@ from airflow.providers.google.common.consts import CLIENT_INFO from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS -from airflow.utils.log.file_task_handler import FileTaskHandler, RawLogStream +from airflow.utils.log.file_task_handler import FileTaskHandler from airflow.utils.log.logging_mixin import LoggingMixin if TYPE_CHECKING: From 3302531ee846d0d651566f9d8fda106164ced681 Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Tue, 23 Dec 2025 22:43:45 +0800 Subject: [PATCH 07/10] Fix mypy error Fix missing mock for get_credentials_and_project_id Fix mypy error Fix test --- .../google/cloud/log/gcs_task_handler.py | 10 +-- .../google/cloud/log/test_gcs_task_handler.py | 62 ++++++++++++------- 2 files changed, 44 insertions(+), 28 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py b/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py index 0a3dfe32c4bdb..3e6738f58de6f 100644 --- a/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py +++ b/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py @@ -43,7 +43,7 @@ from airflow.utils.log.logging_mixin import LoggingMixin if TYPE_CHECKING: - from io import IOBase + from io import TextIOWrapper from airflow.models.taskinstance import TaskInstance from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI @@ -153,7 +153,7 @@ def no_log_found(exc): def read(self, relative_path: str, ti: RuntimeTI) -> 
diff --git a/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py b/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py
index 5a2565c48a1b3..0b6a5136b95e0 100644
--- a/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py
+++ b/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py
@@ -49,7 +49,10 @@ def patch_mock_client_for_list_blobs(mock_client: MagicMock, blob_names: list[st
     mock_client.return_value.list_blobs.return_value = mock_blobs
 
 
-@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="This path only works on Airflow 3")
+@mock.patch(
+    "airflow.providers.google.cloud.log.gcs_task_handler.get_credentials_and_project_id",
+    return_value=("TEST_CREDENTIALS", "TEST_PROJECT_ID"),
+)
 class TestGCSRemoteLogIO:
     @pytest.fixture(autouse=True)
     def setup_tests(self, create_runtime_ti):
@@ -83,11 +86,12 @@ def test_upload(
         mock_rmtree,
         mock_blob,
         mock_client,
-        tmp_path: Path,
-        is_absolute: bool,
-        file_exists: bool,
-        delete_local_copy: bool,
+        mock_creds,
         mock_write_method_result: bool,
+        delete_local_copy: bool,
+        file_exists: bool,
+        is_absolute: bool,
+        tmp_path: Path,
     ):
@@ -124,17 +128,23 @@ def test_upload(
-    @pytest.mark.parametrize(
-        "old_log_exists",
-        [pytest.param(True, id="old-log-exists"), pytest.param(False, id="old-log-not-exists")],
-    )
     @pytest.mark.parametrize(
         "upload_success",
         [pytest.param(True, id="upload-success"), pytest.param(False, id="upload-fail")],
     )
+    @pytest.mark.parametrize(
+        "old_log_exists",
+        [
+            pytest.param(True, id="old-log-exists"),
+            pytest.param(False, id="old-log-not-exists"),
+        ],
+    )
     @pytest.mark.parametrize(
         "old_log_read_error",
-        [pytest.param(None, id="no-read-error"), pytest.param(Exception("Read error"), id="read-error")],
+        [
+            pytest.param(None, id="old-log-read-success"),
+            pytest.param(Exception("Read error"), id="old-log-read-error"),
+        ],
     )
     @mock.patch("google.cloud.storage.Client")
     @mock.patch("google.cloud.storage.Blob")
     def test_write(
         self,
         mock_blob,
         mock_client,
-        old_log_exists: bool,
+        mock_creds,
         upload_success: bool,
+        old_log_exists: bool,
         old_log_read_error: Exception | None,
     ):
@@ -196,7 +206,7 @@ def test_write(
         assert result == upload_success
 
         # verify the content that was uploaded
-        if upload_success or not upload_success:  # upload_from_string is called regardless
+        if upload_success:
             call_args = mock_blob.from_string.return_value.upload_from_string.call_args
             if call_args:
                 uploaded_content = call_args[0][0]
@@ -228,9 +238,10 @@ def test_stream_and_read_methods(
         self,
         mock_blob,
         mock_client,
-        is_stream_method: bool,
-        blob_names: list[str],
+        mock_creds,
         read_success: bool,
+        blob_names: list[str],
+        is_stream_method: bool,
     ):
@@ -252,27 +263,31 @@ def test_stream_and_read_methods(
         )
         # action
         if is_stream_method:
-            messages, logs = gcs_remote_log_io.stream("airflow/logs/task_1", self.ti)
+            messages, log_streams = gcs_remote_log_io.stream("airflow/logs/task_1", self.ti)
+            logs = log_streams  # type: ignore[assignment]
         else:
             messages, logs = gcs_remote_log_io.read("airflow/logs/task_1", self.ti)
 
         # early return for no blobs
         if not blob_names:
             assert messages == []
-            assert logs is None
+            if is_stream_method:
+                assert logs == []
+            else:
+                assert logs is None
             return
 
         # verify messages
-        expected_messages = [
+        expected_uris = [
             f"{self.gcs_log_folder}/task_1/attempt_1.log",
             f"{self.gcs_log_folder}/task_1/attempt_2.log",
         ]
-        if not AIRFLOW_V_3_0_PLUS:
-            expected_messages = messages.extend(
-                ["Found remote logs:", *[f"  * {x}" for x in sorted(expected_messages)]]
-            )
+        if AIRFLOW_V_3_0_PLUS:
+            expected_messages = expected_uris
+        else:
+            expected_messages = ["Found remote logs:", *[f"  * {x}" for x in sorted(expected_uris)]]
         if not read_success and not AIRFLOW_V_3_0_PLUS:
-            expected_messages + [f"Unable to read remote log {Exception('Read failed')}"]
+            expected_messages = expected_messages + ["Unable to read remote log Read failed"]
         assert messages == expected_messages
@@ -291,7 +306,7 @@ def test_stream_and_read_methods(
             if read_success:
                 assert logs == expected_logs
             else:
-                assert logs == []
+                assert logs is None
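The mypy error fixed in the patch above is the annotation log_streams: RawLogStream = [] — a list of generators typed as a single generator. Assuming RawLogStream is a generator-of-lines alias (its real definition lives in airflow.utils.log.file_task_handler), the corrected shape looks like:

    from collections.abc import Generator

    # Assumed stand-in for Airflow's RawLogStream alias.
    RawLogStream = Generator[str, None, None]

    def one_stream() -> RawLogStream:
        yield "line\n"

    # One stream vs. a collection of streams: the fix is list[RawLogStream].
    streams: list[RawLogStream] = [one_stream(), one_stream()]
    assert ["".join(s) for s in streams] == ["line\n", "line\n"]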
From d1c3216a88c737775f6bbbca1a2942e3d01a7110 Mon Sep 17 00:00:00 2001
From: LIU ZHE YOU
Date: Wed, 24 Dec 2025 00:54:03 +0800
Subject: [PATCH 08/10] Fix mypy and unit test

Skip 2.11 test

---
 .../tests/unit/google/cloud/log/test_gcs_task_handler.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py b/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py
index 0b6a5136b95e0..c8922ba107402 100644
--- a/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py
+++ b/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py
@@ -28,7 +28,7 @@
 
 import pytest
 
 from airflow.providers.google.cloud.log.gcs_task_handler import GCSRemoteLogIO, GCSTaskHandler
-from airflow.sdk import BaseOperator
+from airflow.providers.google.version_compat import BaseOperator
 from airflow.utils.state import TaskInstanceState
 from airflow.utils.timezone import datetime
@@ -49,6 +49,7 @@ def patch_mock_client_for_list_blobs(mock_client: MagicMock, blob_names: list[st
     mock_client.return_value.list_blobs.return_value = mock_blobs
 
 
+@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test only for Airflow 3.0+")
 @mock.patch(
     "airflow.providers.google.cloud.log.gcs_task_handler.get_credentials_and_project_id",
     return_value=("TEST_CREDENTIALS", "TEST_PROJECT_ID"),
@@ -242,7 +243,7 @@ def test_stream_and_read_methods(
             messages, log_streams = gcs_remote_log_io.stream("airflow/logs/task_1", self.ti)
             logs = log_streams  # type: ignore[assignment]
         else:
-            messages, logs = gcs_remote_log_io.read("airflow/logs/task_1", self.ti)
+            messages, logs = gcs_remote_log_io.read("airflow/logs/task_1", self.ti)  # type: ignore[assignment]
 
         # early return for no blobs
         if not blob_names:
gcs_remote_log_io.read("airflow/logs/task_1", self.ti) + messages, logs = gcs_remote_log_io.read("airflow/logs/task_1", self.ti) # type: ignore[assignment] # early return for no blobs if not blob_names: From 7600c5d03c56ec2e47e3e31d192f63ce66c9883c Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Wed, 24 Dec 2025 15:31:39 +0800 Subject: [PATCH 09/10] Fix compat test --- .../google/cloud/log/gcs_task_handler.py | 2 +- .../google/cloud/log/test_gcs_task_handler.py | 15 +++++++++++---- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py b/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py index 3e6738f58de6f..922708f83b731 100644 --- a/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py +++ b/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py @@ -313,7 +313,7 @@ def _read_remote_logs(self, ti, try_number, metadata=None) -> LogResponse: if logs is None: logs = [] - if not AIRFLOW_V_3_0_PLUS: + if not AIRFLOW_V_3_0_PLUS and not messages: messages.append(f"No logs found in GCS; ti={ti}") return messages, logs diff --git a/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py b/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py index c8922ba107402..ad0f20b289013 100644 --- a/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py +++ b/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py @@ -28,7 +28,6 @@ import pytest from airflow.providers.google.cloud.log.gcs_task_handler import GCSRemoteLogIO, GCSTaskHandler -from airflow.providers.google.version_compat import BaseOperator from airflow.utils.state import TaskInstanceState from airflow.utils.timezone import datetime @@ -49,18 +48,26 @@ def patch_mock_client_for_list_blobs(mock_client: MagicMock, blob_names: list[st mock_client.return_value.list_blobs.return_value = mock_blobs -@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test only for Airflow 3.0+") +@pytest.mark.db_test @mock.patch( "airflow.providers.google.cloud.log.gcs_task_handler.get_credentials_and_project_id", return_value=("TEST_CREDENTIALS", "TEST_PROJECT_ID"), ) class TestGCSRemoteLogIO: @pytest.fixture(autouse=True) - def setup_tests(self, create_runtime_ti): + def setup_tests(self, create_task_instance, session): # setup remote IO self.base_log_folder = "local/airflow/logs" self.gcs_log_folder = "gs://bucket/airflow/logs" - self.ti = create_runtime_ti(BaseOperator(task_id="task_1")) + # setup TaskInstance + self.ti = ti = create_task_instance(task_id="task_1") + ti.try_number = 1 + ti.raw = False + session.add(ti) + session.commit() + yield + clear_db_runs() + clear_db_dags() @pytest.mark.parametrize( "is_absolute", From ac399b59ca4eed4f67f5e709c880b18b49592d0f Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Fri, 26 Dec 2025 14:16:50 +0800 Subject: [PATCH 10/10] Fix review comment --- .../airflow/providers/google/cloud/log/gcs_task_handler.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py b/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py index 922708f83b731..f6540cdf41cb0 100644 --- a/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py +++ b/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py @@ -159,9 +159,7 @@ def read(self, relative_path: str, ti: RuntimeTI) -> LogResponse: 
From ac399b59ca4eed4f67f5e709c880b18b49592d0f Mon Sep 17 00:00:00 2001
From: LIU ZHE YOU
Date: Fri, 26 Dec 2025 14:16:50 +0800
Subject: [PATCH 10/10] Fix review comment

---
 .../airflow/providers/google/cloud/log/gcs_task_handler.py   | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py b/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py
index 922708f83b731..f6540cdf41cb0 100644
--- a/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py
+++ b/providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py
@@ -159,9 +159,7 @@ def read(self, relative_path: str, ti: RuntimeTI) -> LogResponse:
         logs: list[str] = []
         try:
             # for each log_stream, exhaust the generator into a string
-            for log_stream in log_streams:
-                log_content = "".join(line for line in log_stream)
-                logs.append(log_content)
+            logs = ["".join(line for line in log_stream) for log_stream in log_streams]
         except Exception as e:
             if not AIRFLOW_V_3_0_PLUS:
                 messages.append(f"Unable to read remote log {e}")
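End state of the series: stream() hands back one line-generator per remote log file, and read() collapses each generator into a single string (now via the list comprehension above). A self-contained sketch of that relationship, with fake_stream standing in for a blob-backed generator:

    from collections.abc import Generator

    RawLogStream = Generator[str, None, None]  # assumed alias shape

    def fake_stream(text: str) -> RawLogStream:
        yield from text.splitlines(keepends=True)

    # What stream() returns: one generator per remote log file.
    log_streams: list[RawLogStream] = [fake_stream("a\nb\n"), fake_stream("c\n")]
    # What read() does with it, as of the last patch.
    logs = ["".join(line for line in log_stream) for log_stream in log_streams]
    assert logs == ["a\nb\n", "c\n"]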