diff --git a/airflow-core/docs/administration-and-deployment/dag-bundles.rst b/airflow-core/docs/administration-and-deployment/dag-bundles.rst index 60d799e4c2211..da71b20fdd3ec 100644 --- a/airflow-core/docs/administration-and-deployment/dag-bundles.rst +++ b/airflow-core/docs/administration-and-deployment/dag-bundles.rst @@ -53,6 +53,9 @@ Airflow supports multiple types of Dag Bundles, each catering to specific use ca **airflow.providers.amazon.aws.bundles.s3.S3DagBundle** These bundles reference an S3 bucket containing Dag files. They do not support versioning of the bundle, meaning tasks always run using the latest code. +**airflow.providers.google.cloud.bundles.gcs.GCSDagBundle** + These bundles reference a GCS bucket containing Dag files. They do not support versioning of the bundle, meaning tasks always run using the latest code. + Configuring Dag bundles ----------------------- diff --git a/providers/google/src/airflow/providers/google/cloud/bundles/__init__.py b/providers/google/src/airflow/providers/google/cloud/bundles/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/google/src/airflow/providers/google/cloud/bundles/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/google/src/airflow/providers/google/cloud/bundles/gcs.py b/providers/google/src/airflow/providers/google/cloud/bundles/gcs.py new file mode 100644 index 0000000000000..0039030039a7b --- /dev/null +++ b/providers/google/src/airflow/providers/google/cloud/bundles/gcs.py @@ -0,0 +1,161 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
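To make the documentation entry above concrete, here is a minimal, hypothetical registration sketch. It assumes the ``[dag_processor] dag_bundle_config_list`` option described in the "Configuring Dag bundles" section of that document; the bundle name, bucket, prefix, and connection id are illustrative values, not part of this change.

.. code-block:: python

    import json
    import os

    # Hypothetical bundle entry; bucket, prefix, and connection id are examples only.
    gcs_bundle = {
        "name": "example-gcs-dags",
        "classpath": "airflow.providers.google.cloud.bundles.gcs.GCSDagBundle",
        "kwargs": {
            "bucket_name": "my-airflow-dags-bucket",
            "prefix": "project1/dags",
            "gcp_conn_id": "google_cloud_default",
        },
    }

    # Environment-variable form of [dag_processor] dag_bundle_config_list.
    os.environ["AIRFLOW__DAG_PROCESSOR__DAG_BUNDLE_CONFIG_LIST"] = json.dumps([gcs_bundle])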
+from __future__ import annotations
+
+import os
+from pathlib import Path
+
+import structlog
+from google.api_core.exceptions import NotFound
+
+from airflow.dag_processing.bundles.base import BaseDagBundle
+from airflow.exceptions import AirflowException
+from airflow.providers.google.cloud.hooks.gcs import GCSHook
+from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+
+
+class GCSDagBundle(BaseDagBundle):
+    """
+    GCS Dag bundle - exposes a directory in GCS as a Dag bundle.
+
+    This allows Airflow to load Dags directly from a GCS bucket.
+
+    :param gcp_conn_id: Airflow connection ID for GCS. Defaults to GoogleBaseHook.default_conn_name.
+    :param bucket_name: The name of the GCS bucket containing the Dag files.
+    :param prefix: Optional subdirectory within the GCS bucket where the Dags are stored.
+        If empty, Dags are assumed to be at the root of the bucket.
+    """
+
+    supports_versioning = False
+
+    def __init__(
+        self,
+        *,
+        gcp_conn_id: str = GoogleBaseHook.default_conn_name,
+        bucket_name: str,
+        prefix: str = "",
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.gcp_conn_id = gcp_conn_id
+        self.bucket_name = bucket_name
+        self.prefix = prefix
+        # Local path where GCS Dags are downloaded
+        self.gcs_dags_dir: Path = self.base_dir
+
+        log = structlog.get_logger(__name__)
+        self._log = log.bind(
+            bundle_name=self.name,
+            version=self.version,
+            bucket_name=self.bucket_name,
+            prefix=self.prefix,
+            gcp_conn_id=self.gcp_conn_id,
+        )
+        self._gcs_hook: GCSHook | None = None
+
+    def _initialize(self):
+        with self.lock():
+            if not self.gcs_dags_dir.exists():
+                self._log.info("Creating local Dags directory: %s", self.gcs_dags_dir)
+                os.makedirs(self.gcs_dags_dir)
+
+            if not self.gcs_dags_dir.is_dir():
+                raise NotADirectoryError(f"Local Dags path: {self.gcs_dags_dir} is not a directory.")
+
+            try:
+                self.gcs_hook.get_bucket(bucket_name=self.bucket_name)
+            except NotFound:
+                raise ValueError(f"GCS bucket '{self.bucket_name}' does not exist.")
+
+            if self.prefix:
+                # don't check when prefix is ""
+                if not self.gcs_hook.list(bucket_name=self.bucket_name, prefix=self.prefix):
+                    raise ValueError(f"GCS prefix 'gs://{self.bucket_name}/{self.prefix}' does not exist.")
+            self.refresh()
+
+    def initialize(self) -> None:
+        self._initialize()
+        super().initialize()
+
+    @property
+    def gcs_hook(self):
+        if self._gcs_hook is None:
+            try:
+                self._gcs_hook: GCSHook = GCSHook(gcp_conn_id=self.gcp_conn_id)  # Initialize GCS hook.
+            except AirflowException as e:
+                self._log.warning("Could not create GCSHook for connection %s: %s", self.gcp_conn_id, e)
+        return self._gcs_hook
+
+    def __repr__(self):
+        return (
+            f"<GCSDagBundle("
+            f"name={self.name!r}, "
+            f"bucket_name={self.bucket_name!r}, "
+            f"prefix={self.prefix!r}, "
+            f"version={self.version!r}"
+            f")>"
+        )
+
+    def get_current_version(self) -> str | None:
+        """Return the current version of the Dag bundle. Currently not supported."""
+        return None
+
+    @property
+    def path(self) -> Path:
+        """Return the local path to the Dag files."""
+        return self.gcs_dags_dir  # Path where Dags are downloaded.
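As a usage illustration of the constructor documented above, the sketch below builds a bundle by hand, the way the Dag processor would after reading the bundle configuration. The bucket, prefix, and connection id are placeholder values; it assumes a working Google Cloud connection and a configured bundle storage path.

.. code-block:: python

    from airflow.providers.google.cloud.bundles.gcs import GCSDagBundle

    # Placeholder values; real deployments pass these through the bundle config.
    bundle = GCSDagBundle(
        name="example-gcs-dags",
        bucket_name="my-airflow-dags-bucket",
        prefix="project1/dags",
        gcp_conn_id="google_cloud_default",
    )

    bundle.initialize()  # validates the bucket/prefix, then downloads the Dags locally
    print(bundle.path)  # local directory that the Dag processor parses
    print(bundle.get_current_version())  # None: this bundle type is unversioned
    bundle.refresh()  # re-syncs the local copy with the bucket contents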
+
+    def refresh(self) -> None:
+        """Refresh the Dag bundle by re-downloading the Dags from GCS."""
+        if self.version:
+            raise ValueError("Refreshing a specific version is not supported")
+
+        with self.lock():
+            self._log.debug(
+                "Downloading Dags from gs://%s/%s to %s", self.bucket_name, self.prefix, self.gcs_dags_dir
+            )
+            self.gcs_hook.sync_to_local_dir(
+                bucket_name=self.bucket_name,
+                prefix=self.prefix,
+                local_dir=self.gcs_dags_dir,
+                delete_stale=True,
+            )
+
+    def view_url(self, version: str | None = None) -> str | None:
+        """
+        Return a URL for viewing the Dags in GCS. Currently, versioning is not supported.
+
+        This method is deprecated and will be removed when the minimum supported Airflow version is 3.1.
+        Use `view_url_template` instead.
+        """
+        return self.view_url_template()
+
+    def view_url_template(self) -> str | None:
+        """Return a URL for viewing the Dags in GCS. Currently, versioning is not supported."""
+        if self.version:
+            raise ValueError("GCS url with version is not supported")
+        if hasattr(self, "_view_url_template") and self._view_url_template:
+            # Because we use this method in the view_url method, we need to handle
+            # backward compatibility for Airflow versions that don't have the
+            # _view_url_template attribute. Should be removed when we drop support for Airflow 3.0
+            return self._view_url_template
+        # https://console.cloud.google.com/storage/browser/<bucket_name>/<prefix>
+        url = f"https://console.cloud.google.com/storage/browser/{self.bucket_name}"
+        if self.prefix:
+            url += f"/{self.prefix}"
+
+        return url
diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/gcs.py b/providers/google/src/airflow/providers/google/cloud/hooks/gcs.py
index a063b39b1a35f..e05aab2bc9293 100644
--- a/providers/google/src/airflow/providers/google/cloud/hooks/gcs.py
+++ b/providers/google/src/airflow/providers/google/cloud/hooks/gcs.py
@@ -28,8 +28,10 @@
 import warnings
 from collections.abc import Callable, Generator, Sequence
 from contextlib import contextmanager
+from datetime import datetime
 from functools import partial
 from io import BytesIO
+from pathlib import Path
 from tempfile import NamedTemporaryFile
 from typing import IO, TYPE_CHECKING, Any, ParamSpec, TypeVar, cast, overload
 from urllib.parse import urlsplit
@@ -50,12 +52,14 @@
     GoogleBaseAsyncHook,
     GoogleBaseHook,
 )
-from airflow.utils import timezone
+
+try:
+    from airflow.sdk import timezone
+except ImportError:
+    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
 from airflow.version import version
 
 if TYPE_CHECKING:
-    from datetime import datetime
-
     from aiohttp import ClientSession
     from google.api_core.retry import Retry
     from google.cloud.storage.blob import Blob
@@ -1249,6 +1253,106 @@ def compose(self, bucket_name: str, source_objects: List[str], destination_objec
 
         self.log.info("Completed successfully.")
 
+    def _sync_to_local_dir_delete_stale_local_files(self, current_gcs_objects: List[Path], local_dir: Path):
+        current_gcs_keys = {key.resolve() for key in current_gcs_objects}
+
+        for item in local_dir.rglob("*"):
+            if item.is_file():
+                if item.resolve() not in current_gcs_keys:
+                    self.log.debug("Deleting stale local file: %s", item)
+                    item.unlink()
+        # Clean up empty directories
+        for root, dirs, _ in os.walk(local_dir, topdown=False):
+            for d in dirs:
+                dir_path = os.path.join(root, d)
+                if not os.listdir(dir_path):
+                    self.log.debug("Deleting stale empty directory: %s", dir_path)
+                    os.rmdir(dir_path)
+
+    def _sync_to_local_dir_if_changed(self, blob: Blob, local_target_path: Path):
+        
should_download = False + download_msg = "" + if not local_target_path.exists(): + should_download = True + download_msg = f"Local file {local_target_path} does not exist." + else: + local_stats = local_target_path.stat() + # Reload blob to get fresh metadata, including size and updated time + blob.reload() + + if blob.size != local_stats.st_size: + should_download = True + download_msg = ( + f"GCS object size ({blob.size}) and local file size ({local_stats.st_size}) differ." + ) + + gcs_last_modified = blob.updated + if ( + not should_download + and gcs_last_modified + and local_stats.st_mtime < gcs_last_modified.timestamp() + ): + should_download = True + download_msg = f"GCS object last modified ({gcs_last_modified}) is newer than local file last modified ({datetime.fromtimestamp(local_stats.st_mtime, tz=timezone.utc)})." + + if should_download: + self.log.debug("%s Downloading %s to %s", download_msg, blob.name, local_target_path.as_posix()) + self.download( + bucket_name=blob.bucket.name, object_name=blob.name, filename=str(local_target_path) + ) + else: + self.log.debug( + "Local file %s is up-to-date with GCS object %s. Skipping download.", + local_target_path.as_posix(), + blob.name, + ) + + def sync_to_local_dir( + self, + bucket_name: str, + local_dir: str | Path, + prefix: str | None = None, + delete_stale: bool = False, + ) -> None: + """ + Download files from a GCS bucket to a local directory. + + It will download all files from the given ``prefix`` and create the corresponding + directory structure in the ``local_dir``. + + If ``delete_stale`` is ``True``, it will delete all local files that do not exist in the GCS bucket. + + :param bucket_name: The name of the GCS bucket. + :param local_dir: The local directory to which the files will be downloaded. + :param prefix: The prefix of the files to be downloaded. + :param delete_stale: If ``True``, deletes local files that don't exist in the bucket. + """ + prefix = prefix or "" + local_dir_path = Path(local_dir) + self.log.debug("Downloading data from gs://%s/%s to %s", bucket_name, prefix, local_dir_path) + + gcs_bucket = self.get_bucket(bucket_name) + local_gcs_objects = [] + + for blob in gcs_bucket.list_blobs(prefix=prefix): + # GCS lists "directories" as objects ending with a slash. We should skip them. + if blob.name.endswith("/"): + continue + + blob_path = Path(blob.name) + local_target_path = local_dir_path.joinpath(blob_path.relative_to(prefix)) + if not local_target_path.parent.exists(): + local_target_path.parent.mkdir(parents=True, exist_ok=True) + self.log.debug("Created local directory: %s", local_target_path.parent) + + self._sync_to_local_dir_if_changed(blob=blob, local_target_path=local_target_path) + local_gcs_objects.append(local_target_path) + + if delete_stale: + self._sync_to_local_dir_delete_stale_local_files( + current_gcs_objects=local_gcs_objects, local_dir=local_dir_path + ) + def sync( self, source_bucket: str, diff --git a/providers/google/tests/unit/google/cloud/bundles/__init__.py b/providers/google/tests/unit/google/cloud/bundles/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/google/tests/unit/google/cloud/bundles/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/google/tests/unit/google/cloud/bundles/test_gcs.py b/providers/google/tests/unit/google/cloud/bundles/test_gcs.py new file mode 100644 index 0000000000000..12a9b1ca5496d --- /dev/null +++ b/providers/google/tests/unit/google/cloud/bundles/test_gcs.py @@ -0,0 +1,224 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import MagicMock, PropertyMock, call, patch + +import pytest +from google.api_core.exceptions import NotFound + +import airflow.version +from airflow.models import Connection + +from tests_common.test_utils.config import conf_vars + +if airflow.version.version.strip().startswith("3"): + from airflow.providers.google.cloud.bundles.gcs import GCSDagBundle + +GCP_CONN_ID = "gcs_dags_connection" +GCS_BUCKET_NAME = "my-airflow-dags-bucket" +GCS_BUCKET_PREFIX = "project1/dags" + + +@pytest.fixture(autouse=True) +def bundle_temp_dir(tmp_path): + with conf_vars({("dag_processor", "dag_bundle_storage_path"): str(tmp_path)}): + yield tmp_path + + +@pytest.mark.skipif(not airflow.version.version.strip().startswith("3"), reason="Airflow >=3.0.0 test") +class TestGCSDagBundle: + @pytest.fixture(autouse=True) + def setup_connections(self, create_connection_without_db): + create_connection_without_db( + Connection( + conn_id=GCP_CONN_ID, + conn_type="google_cloud_platform", + ) + ) + + def test_view_url_generates_console_url(self): + bundle = GCSDagBundle( + name="test", gcp_conn_id=GCP_CONN_ID, prefix=GCS_BUCKET_PREFIX, bucket_name=GCS_BUCKET_NAME + ) + + url: str = bundle.view_url() + assert ( + url == f"https://console.cloud.google.com/storage/browser/{GCS_BUCKET_NAME}/{GCS_BUCKET_PREFIX}" + ) + + def test_view_url_template_generates_console_url(self): + bundle = GCSDagBundle( + name="test", gcp_conn_id=GCP_CONN_ID, prefix=GCS_BUCKET_PREFIX, bucket_name=GCS_BUCKET_NAME + ) + url: str = bundle.view_url_template() + assert ( + url == f"https://console.cloud.google.com/storage/browser/{GCS_BUCKET_NAME}/{GCS_BUCKET_PREFIX}" + ) + + def test_supports_versioning(self): + bundle = GCSDagBundle( + name="test", gcp_conn_id=GCP_CONN_ID, prefix=GCS_BUCKET_PREFIX, bucket_name=GCS_BUCKET_NAME + ) + assert 
GCSDagBundle.supports_versioning is False + + # set version, it's not supported + bundle.version = "test_version" + + with pytest.raises(ValueError, match="Refreshing a specific version is not supported"): + bundle.refresh() + with pytest.raises(ValueError, match="GCS url with version is not supported"): + bundle.view_url("test_version") + + def test_local_dags_path_is_not_a_directory(self, bundle_temp_dir): + bundle_name = "test" + # Create a file where the directory should be + file_path = bundle_temp_dir / bundle_name + file_path.touch() + + bundle = GCSDagBundle( + name=bundle_name, + gcp_conn_id=GCP_CONN_ID, + prefix="project1_dags", + bucket_name="airflow_dags", + ) + with pytest.raises(NotADirectoryError, match=f"Local Dags path: {file_path} is not a directory."): + bundle.initialize() + + def test_correct_bundle_path_used(self): + bundle = GCSDagBundle( + name="test", gcp_conn_id=GCP_CONN_ID, prefix="project1_dags", bucket_name="airflow_dags" + ) + assert str(bundle.base_dir) == str(bundle.gcs_dags_dir) + + @patch("airflow.providers.google.cloud.bundles.gcs.GCSDagBundle.gcs_hook", new_callable=PropertyMock) + def test_gcs_bucket_and_prefix_validated(self, mock_gcs_hook_property): + mock_hook = MagicMock() + mock_gcs_hook_property.return_value = mock_hook + + mock_hook.get_bucket.side_effect = NotFound("Bucket not found") + bundle = GCSDagBundle( + name="test", + gcp_conn_id=GCP_CONN_ID, + prefix="project1_dags", + bucket_name="non-existing-bucket", + ) + with pytest.raises(ValueError, match="GCS bucket 'non-existing-bucket' does not exist."): + bundle.initialize() + mock_hook.get_bucket.assert_called_once_with(bucket_name="non-existing-bucket") + + mock_hook.get_bucket.side_effect = None + mock_hook.get_bucket.return_value = True + mock_hook.list.return_value = [] + bundle = GCSDagBundle( + name="test", + gcp_conn_id=GCP_CONN_ID, + prefix="non-existing-prefix", + bucket_name=GCS_BUCKET_NAME, + ) + with pytest.raises( + ValueError, + match=f"GCS prefix 'gs://{GCS_BUCKET_NAME}/non-existing-prefix' does not exist.", + ): + bundle.initialize() + mock_hook.list.assert_called_once_with(bucket_name=GCS_BUCKET_NAME, prefix="non-existing-prefix") + + mock_hook.list.return_value = ["some/object"] + bundle = GCSDagBundle( + name="test", + gcp_conn_id=GCP_CONN_ID, + prefix=GCS_BUCKET_PREFIX, + bucket_name=GCS_BUCKET_NAME, + ) + # initialize succeeds, with correct prefix and bucket + bundle.initialize() + + mock_hook.list.reset_mock() + bundle = GCSDagBundle( + name="test", + gcp_conn_id=GCP_CONN_ID, + prefix="", + bucket_name=GCS_BUCKET_NAME, + ) + # initialize succeeds, with empty prefix + bundle.initialize() + mock_hook.list.assert_not_called() + + @patch("airflow.providers.google.cloud.bundles.gcs.GCSDagBundle.gcs_hook", new_callable=PropertyMock) + def test_refresh(self, mock_gcs_hook_property): + mock_hook = MagicMock() + mock_gcs_hook_property.return_value = mock_hook + + bundle = GCSDagBundle( + name="test", + gcp_conn_id=GCP_CONN_ID, + prefix=GCS_BUCKET_PREFIX, + bucket_name=GCS_BUCKET_NAME, + ) + bundle._log.debug = MagicMock() + download_log_call = call( + "Downloading Dags from gs://%s/%s to %s", GCS_BUCKET_NAME, GCS_BUCKET_PREFIX, bundle.gcs_dags_dir + ) + sync_call = call( + bucket_name=GCS_BUCKET_NAME, + prefix=GCS_BUCKET_PREFIX, + local_dir=bundle.gcs_dags_dir, + delete_stale=True, + ) + + bundle.initialize() + assert bundle._log.debug.call_count == 1 + assert bundle._log.debug.call_args_list == [download_log_call] + assert mock_hook.sync_to_local_dir.call_count == 1 + assert 
mock_hook.sync_to_local_dir.call_args_list == [sync_call] + + bundle.refresh() + assert bundle._log.debug.call_count == 2 + assert bundle._log.debug.call_args_list == [download_log_call, download_log_call] + assert mock_hook.sync_to_local_dir.call_count == 2 + assert mock_hook.sync_to_local_dir.call_args_list == [sync_call, sync_call] + + @patch("airflow.providers.google.cloud.bundles.gcs.GCSDagBundle.gcs_hook", new_callable=PropertyMock) + def test_refresh_without_prefix(self, mock_gcs_hook_property): + mock_hook = MagicMock() + mock_gcs_hook_property.return_value = mock_hook + + bundle = GCSDagBundle( + name="test", + gcp_conn_id=GCP_CONN_ID, + bucket_name=GCS_BUCKET_NAME, + ) + bundle._log.debug = MagicMock() + download_log_call = call( + "Downloading Dags from gs://%s/%s to %s", GCS_BUCKET_NAME, "", bundle.gcs_dags_dir + ) + sync_call = call( + bucket_name=GCS_BUCKET_NAME, prefix="", local_dir=bundle.gcs_dags_dir, delete_stale=True + ) + + assert bundle.prefix == "" + bundle.initialize() + assert bundle._log.debug.call_count == 1 + assert bundle._log.debug.call_args_list == [download_log_call] + assert mock_hook.sync_to_local_dir.call_count == 1 + assert mock_hook.sync_to_local_dir.call_args_list == [sync_call] + + bundle.refresh() + assert bundle._log.debug.call_count == 2 + assert bundle._log.debug.call_args_list == [download_log_call, download_log_call] + assert mock_hook.sync_to_local_dir.call_count == 2 + assert mock_hook.sync_to_local_dir.call_args_list == [sync_call, sync_call] diff --git a/providers/google/tests/unit/google/cloud/hooks/test_gcs.py b/providers/google/tests/unit/google/cloud/hooks/test_gcs.py index 96275756b9324..028ba00e19abf 100644 --- a/providers/google/tests/unit/google/cloud/hooks/test_gcs.py +++ b/providers/google/tests/unit/google/cloud/hooks/test_gcs.py @@ -1736,6 +1736,8 @@ def _create_blob( bucket: MagicMock | None = None, kms_key_name: str | None = None, generation: int = 0, + size: int = 9, + updated: datetime | None = None, ): blob = mock.MagicMock(name=f"BLOB:{name}") blob.name = name @@ -1743,9 +1745,105 @@ def _create_blob( blob.bucket = bucket blob.kms_key_name = kms_key_name blob.generation = generation + blob.size = size + blob.updated = updated or timezone.utcnow() return blob def _create_bucket(self, name: str): bucket = mock.MagicMock(name=f"BUCKET:{name}") bucket.name = name return bucket + + @mock.patch(GCS_STRING.format("GCSHook.get_conn")) + def test_sync_to_local_dir_behaviour(self, mock_get_conn, tmp_path): + def get_logs_string(call_args_list): + return "".join([args[0][0] % args[0][1:] for args in call_args_list]) + + test_bucket = "test_bucket" + mock_bucket = self._create_bucket(name=test_bucket) + mock_get_conn.return_value.bucket.return_value = mock_bucket + + blobs = [ + self._create_blob("dag_01.py", "C1", mock_bucket), + self._create_blob("dag_02.py", "C1", mock_bucket), + self._create_blob("subproject1/dag_a.py", "C1", mock_bucket), + self._create_blob("subproject1/dag_b.py", "C1", mock_bucket), + ] + mock_bucket.list_blobs.return_value = blobs + + sync_local_dir = tmp_path / "gcs_sync_dir" + self.gcs_hook.log.debug = MagicMock() + self.gcs_hook.download = MagicMock() + + self.gcs_hook.sync_to_local_dir( + bucket_name=test_bucket, local_dir=sync_local_dir, prefix="", delete_stale=True + ) + logs_string = get_logs_string(self.gcs_hook.log.debug.call_args_list) + assert f"Downloading data from gs://{test_bucket}/ to {sync_local_dir}" in logs_string + assert f"Local file {sync_local_dir}/dag_01.py does not exist." 
in logs_string + assert f"Downloading dag_01.py to {sync_local_dir}/dag_01.py" in logs_string + assert f"Local file {sync_local_dir}/subproject1/dag_a.py does not exist." in logs_string + assert f"Downloading subproject1/dag_a.py to {sync_local_dir}/subproject1/dag_a.py" in logs_string + assert self.gcs_hook.download.call_count == 4 + + # Create dummy local files to simulate download + for blob in blobs: + p = sync_local_dir / blob.name + p.parent.mkdir(parents=True, exist_ok=True) + p.write_text("test data") + os.utime(p, (blob.updated.timestamp(), blob.updated.timestamp())) + + # add new file to bucket and sync + self.gcs_hook.log.debug = MagicMock() + self.gcs_hook.download.reset_mock() + new_blob = self._create_blob("dag_03.py", "C1", mock_bucket) + mock_bucket.list_blobs.return_value = blobs + [new_blob] + self.gcs_hook.sync_to_local_dir( + bucket_name=test_bucket, local_dir=sync_local_dir, prefix="", delete_stale=True + ) + logs_string = get_logs_string(self.gcs_hook.log.debug.call_args_list) + assert ( + f"Local file {sync_local_dir}/subproject1/dag_b.py is up-to-date with GCS object subproject1/dag_b.py. Skipping download." + in logs_string + ) + assert f"Local file {sync_local_dir}/dag_03.py does not exist." in logs_string + assert f"Downloading dag_03.py to {sync_local_dir}/dag_03.py" in logs_string + self.gcs_hook.download.assert_called_once() + (sync_local_dir / "dag_03.py").write_text("test data") + os.utime( + sync_local_dir / "dag_03.py", + (new_blob.updated.timestamp(), new_blob.updated.timestamp()), + ) + + # Test deletion of stale files + local_file_that_should_be_deleted = sync_local_dir / "file_that_should_be_deleted.py" + local_file_that_should_be_deleted.write_text("test dag") + local_folder_should_be_deleted = sync_local_dir / "local_folder_should_be_deleted" + local_folder_should_be_deleted.mkdir(exist_ok=True) + self.gcs_hook.log.debug = MagicMock() + self.gcs_hook.download.reset_mock() + self.gcs_hook.sync_to_local_dir( + bucket_name=test_bucket, local_dir=sync_local_dir, prefix="", delete_stale=True + ) + logs_string = get_logs_string(self.gcs_hook.log.debug.call_args_list) + assert f"Deleting stale local file: {local_file_that_should_be_deleted.as_posix()}" in logs_string + assert f"Deleting stale empty directory: {local_folder_should_be_deleted.as_posix()}" in logs_string + assert not self.gcs_hook.download.called + + # Test update of existing file (size change) + self.gcs_hook.log.debug = MagicMock() + self.gcs_hook.download.reset_mock() + updated_blob = self._create_blob( + "dag_03.py", + "C2", + mock_bucket, + size=15, + ) + mock_bucket.list_blobs.return_value = blobs + [updated_blob] + self.gcs_hook.sync_to_local_dir( + bucket_name=test_bucket, local_dir=sync_local_dir, prefix="", delete_stale=True + ) + logs_string = get_logs_string(self.gcs_hook.log.debug.call_args_list) + assert "GCS object size (15) and local file size (9) differ." in logs_string + assert f"Downloading dag_03.py to {sync_local_dir}/dag_03.py" in logs_string + self.gcs_hook.download.assert_called_once()
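Beyond the bundle itself, the new ``GCSHook.sync_to_local_dir`` helper exercised by the tests above can also be called directly. A minimal sketch, assuming a reachable bucket and a default Google Cloud connection; the bucket, prefix, and local path are illustrative:

.. code-block:: python

    from pathlib import Path

    from airflow.providers.google.cloud.hooks.gcs import GCSHook

    hook = GCSHook(gcp_conn_id="google_cloud_default")

    # Mirror gs://my-airflow-dags-bucket/project1/dags into a local directory,
    # downloading only new or changed objects and pruning local files that no
    # longer exist under the prefix.
    hook.sync_to_local_dir(
        bucket_name="my-airflow-dags-bucket",
        prefix="project1/dags",
        local_dir=Path("/tmp/gcs-dags"),
        delete_stale=True,
    )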