110 changes: 107 additions & 3 deletions providers/google/src/airflow/providers/google/cloud/hooks/gcs.py
@@ -28,8 +28,10 @@
import warnings
from collections.abc import Callable, Generator, Sequence
from contextlib import contextmanager
from datetime import datetime
from functools import partial
from io import BytesIO
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import IO, TYPE_CHECKING, Any, ParamSpec, TypeVar, cast, overload
from urllib.parse import urlsplit
@@ -50,12 +52,14 @@
    GoogleBaseAsyncHook,
    GoogleBaseHook,
)
from airflow.utils import timezone

try:
    from airflow.sdk import timezone
except ImportError:
    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
from airflow.version import version

if TYPE_CHECKING:
    from datetime import datetime

    from aiohttp import ClientSession
    from google.api_core.retry import Retry
    from google.cloud.storage.blob import Blob
@@ -1249,6 +1253,106 @@ def compose(self, bucket_name: str, source_objects: List[str], destination_objec

        self.log.info("Completed successfully.")

    def _sync_to_local_dir_delete_stale_local_files(self, current_gcs_objects: list[Path], local_dir: Path):
        current_gcs_keys = {key.resolve() for key in current_gcs_objects}

        for item in local_dir.rglob("*"):
            if item.is_file():
                if item.resolve() not in current_gcs_keys:
                    self.log.debug("Deleting stale local file: %s", item)
                    item.unlink()
        # Clean up empty directories left behind after stale files are removed
        for root, dirs, _ in os.walk(local_dir, topdown=False):
            for d in dirs:
                dir_path = os.path.join(root, d)
                if not os.listdir(dir_path):
                    self.log.debug("Deleting stale empty directory: %s", dir_path)
                    os.rmdir(dir_path)

    def _sync_to_local_dir_if_changed(self, blob: Blob, local_target_path: Path):
        should_download = False
        download_msg = ""
        if not local_target_path.exists():
            should_download = True
            download_msg = f"Local file {local_target_path} does not exist."
        else:
            local_stats = local_target_path.stat()
            # Reload blob to get fresh metadata, including size and updated time
            blob.reload()

            if blob.size != local_stats.st_size:
                should_download = True
                download_msg = (
                    f"GCS object size ({blob.size}) and local file size ({local_stats.st_size}) differ."
                )

            gcs_last_modified = blob.updated
            if (
                not should_download
                and gcs_last_modified
                and local_stats.st_mtime < gcs_last_modified.timestamp()
            ):
                should_download = True
                download_msg = f"GCS object last modified ({gcs_last_modified}) is newer than local file last modified ({datetime.fromtimestamp(local_stats.st_mtime, tz=timezone.utc)})."

        if should_download:
            self.log.debug("%s Downloading %s to %s", download_msg, blob.name, local_target_path.as_posix())
            self.download(
                bucket_name=blob.bucket.name, object_name=blob.name, filename=str(local_target_path)
            )
        else:
            self.log.debug(
                "Local file %s is up-to-date with GCS object %s. Skipping download.",
                local_target_path.as_posix(),
                blob.name,
            )

    def sync_to_local_dir(
        self,
        bucket_name: str,
        local_dir: str | Path,
        prefix: str | None = None,
        delete_stale: bool = False,
    ) -> None:
        """
        Download files from a GCS bucket to a local directory.

        It will download all files from the given ``prefix`` and create the corresponding
        directory structure in the ``local_dir``.

        If ``delete_stale`` is ``True``, it will delete all local files that do not exist in the GCS bucket.

        :param bucket_name: The name of the GCS bucket.
        :param local_dir: The local directory to which the files will be downloaded.
        :param prefix: The prefix of the files to be downloaded.
        :param delete_stale: If ``True``, deletes local files that don't exist in the bucket.
        """
        prefix = prefix or ""
        local_dir_path = Path(local_dir)
        self.log.debug("Downloading data from gs://%s/%s to %s", bucket_name, prefix, local_dir_path)

        gcs_bucket = self.get_bucket(bucket_name)
        local_gcs_objects = []

        for blob in gcs_bucket.list_blobs(prefix=prefix):
            # GCS lists "directories" as objects ending with a slash. We should skip them.
            if blob.name.endswith("/"):
                continue

            blob_path = Path(blob.name)
            local_target_path = local_dir_path.joinpath(blob_path.relative_to(prefix))
            if not local_target_path.parent.exists():
                local_target_path.parent.mkdir(parents=True, exist_ok=True)
                self.log.debug("Created local directory: %s", local_target_path.parent)

            self._sync_to_local_dir_if_changed(blob=blob, local_target_path=local_target_path)
            local_gcs_objects.append(local_target_path)

        if delete_stale:
            self._sync_to_local_dir_delete_stale_local_files(
                current_gcs_objects=local_gcs_objects, local_dir=local_dir_path
            )

    def sync(
        self,
        source_bucket: str,
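A minimal usage sketch of the new hook method (the connection id, bucket name, and local path here are illustrative assumptions, not part of this diff):

    from airflow.providers.google.cloud.hooks.gcs import GCSHook

    hook = GCSHook(gcp_conn_id="google_cloud_default")
    # Mirror everything under the "dags/" prefix into a local directory,
    # removing local files that no longer have a matching object in the bucket.
    hook.sync_to_local_dir(
        bucket_name="my-bucket",
        local_dir="/tmp/dags",
        prefix="dags/",
        delete_stale=True,
    )

Objects whose names end with "/" (GCS directory placeholders) are skipped, and each remaining object lands at ``local_dir`` joined with its name relative to the prefix, e.g. gs://my-bucket/dags/subproject1/dag_a.py maps to /tmp/dags/subproject1/dag_a.py.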
104 changes: 103 additions & 1 deletion providers/google/tests/unit/google/cloud/hooks/test_gcs.py
@@ -39,7 +39,11 @@
from airflow.providers.google.cloud.hooks import gcs
from airflow.providers.google.cloud.hooks.gcs import _fallback_object_url_to_object_name_and_bucket_name
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.utils import timezone

try:
    from airflow.sdk import timezone
except ImportError:
    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
from airflow.version import version

from unit.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id
@@ -1747,16 +1751,114 @@ def _create_blob(
        bucket: MagicMock | None = None,
        kms_key_name: str | None = None,
        generation: int = 0,
        size: int = 9,
        updated: datetime | None = None,
    ):
        blob = mock.MagicMock(name=f"BLOB:{name}")
        blob.name = name
        blob.crc32 = crc32
        blob.bucket = bucket
        blob.kms_key_name = kms_key_name
        blob.generation = generation
        blob.size = size
        blob.updated = updated or timezone.utcnow()
        return blob

    def _create_bucket(self, name: str):
        bucket = mock.MagicMock(name=f"BUCKET:{name}")
        bucket.name = name
        return bucket

    @mock.patch(GCS_STRING.format("GCSHook.get_conn"))
    def test_sync_to_local_dir_behaviour(self, mock_get_conn, tmp_path):
        def get_logs_string(call_args_list):
            return "".join([args[0][0] % args[0][1:] for args in call_args_list])

        test_bucket = "test_bucket"
        mock_bucket = self._create_bucket(name=test_bucket)
        mock_get_conn.return_value.bucket.return_value = mock_bucket

        blobs = [
            self._create_blob("dag_01.py", "C1", mock_bucket),
            self._create_blob("dag_02.py", "C1", mock_bucket),
            self._create_blob("subproject1/dag_a.py", "C1", mock_bucket),
            self._create_blob("subproject1/dag_b.py", "C1", mock_bucket),
        ]
        mock_bucket.list_blobs.return_value = blobs

        sync_local_dir = tmp_path / "gcs_sync_dir"
        self.gcs_hook.log.debug = MagicMock()
        self.gcs_hook.download = MagicMock()

        self.gcs_hook.sync_to_local_dir(
            bucket_name=test_bucket, local_dir=sync_local_dir, prefix="", delete_stale=True
        )
        logs_string = get_logs_string(self.gcs_hook.log.debug.call_args_list)
        assert f"Downloading data from gs://{test_bucket}/ to {sync_local_dir}" in logs_string
        assert f"Local file {sync_local_dir}/dag_01.py does not exist." in logs_string
        assert f"Downloading dag_01.py to {sync_local_dir}/dag_01.py" in logs_string
        assert f"Local file {sync_local_dir}/subproject1/dag_a.py does not exist." in logs_string
        assert f"Downloading subproject1/dag_a.py to {sync_local_dir}/subproject1/dag_a.py" in logs_string
        assert self.gcs_hook.download.call_count == 4

        # Create dummy local files to simulate download
        for blob in blobs:
            p = sync_local_dir / blob.name
            p.parent.mkdir(parents=True, exist_ok=True)
            p.write_text("test data")
            os.utime(p, (blob.updated.timestamp(), blob.updated.timestamp()))

        # add new file to bucket and sync
        self.gcs_hook.log.debug = MagicMock()
        self.gcs_hook.download.reset_mock()
        new_blob = self._create_blob("dag_03.py", "C1", mock_bucket)
        mock_bucket.list_blobs.return_value = blobs + [new_blob]
        self.gcs_hook.sync_to_local_dir(
            bucket_name=test_bucket, local_dir=sync_local_dir, prefix="", delete_stale=True
        )
        logs_string = get_logs_string(self.gcs_hook.log.debug.call_args_list)
        assert (
            f"Local file {sync_local_dir}/subproject1/dag_b.py is up-to-date with GCS object subproject1/dag_b.py. Skipping download."
            in logs_string
        )
        assert f"Local file {sync_local_dir}/dag_03.py does not exist." in logs_string
        assert f"Downloading dag_03.py to {sync_local_dir}/dag_03.py" in logs_string
        self.gcs_hook.download.assert_called_once()
        (sync_local_dir / "dag_03.py").write_text("test data")
        os.utime(
            sync_local_dir / "dag_03.py",
            (new_blob.updated.timestamp(), new_blob.updated.timestamp()),
        )

        # Test deletion of stale files
        local_file_that_should_be_deleted = sync_local_dir / "file_that_should_be_deleted.py"
        local_file_that_should_be_deleted.write_text("test dag")
        local_folder_should_be_deleted = sync_local_dir / "local_folder_should_be_deleted"
        local_folder_should_be_deleted.mkdir(exist_ok=True)
        self.gcs_hook.log.debug = MagicMock()
        self.gcs_hook.download.reset_mock()
        self.gcs_hook.sync_to_local_dir(
            bucket_name=test_bucket, local_dir=sync_local_dir, prefix="", delete_stale=True
        )
        logs_string = get_logs_string(self.gcs_hook.log.debug.call_args_list)
        assert f"Deleting stale local file: {local_file_that_should_be_deleted.as_posix()}" in logs_string
        assert f"Deleting stale empty directory: {local_folder_should_be_deleted.as_posix()}" in logs_string
        assert not self.gcs_hook.download.called

        # Test update of existing file (size change)
        self.gcs_hook.log.debug = MagicMock()
        self.gcs_hook.download.reset_mock()
        updated_blob = self._create_blob(
            "dag_03.py",
            "C2",
            mock_bucket,
            size=15,
        )
        mock_bucket.list_blobs.return_value = blobs + [updated_blob]
        self.gcs_hook.sync_to_local_dir(
            bucket_name=test_bucket, local_dir=sync_local_dir, prefix="", delete_stale=True
        )
        logs_string = get_logs_string(self.gcs_hook.log.debug.call_args_list)
        assert "GCS object size (15) and local file size (9) differ." in logs_string
        assert f"Downloading dag_03.py to {sync_local_dir}/dag_03.py" in logs_string
        self.gcs_hook.download.assert_called_once()
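The get_logs_string helper works because log.debug is called lazily with a format string followed by its arguments, so the rendered messages can be reconstructed from the mock's call_args_list. A standalone sketch of the same idea (values are hypothetical):

    from unittest.mock import MagicMock

    log = MagicMock()
    log.debug("Downloading %s to %s", "dag_01.py", "/tmp/dags/dag_01.py")

    # Each entry in call_args_list is a (args, kwargs) pair; args[0] is the
    # format string and the remaining args are its lazy arguments.
    text = "".join(c.args[0] % c.args[1:] for c in log.debug.call_args_list)
    assert text == "Downloading dag_01.py to /tmp/dags/dag_01.py"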