@@ -53,6 +53,9 @@ Airflow supports multiple types of Dag Bundles, each catering to specific use cases
**airflow.providers.amazon.aws.bundles.s3.S3DagBundle**
These bundles reference an S3 bucket containing Dag files. They do not support versioning of the bundle, meaning tasks always run using the latest code.

**airflow.providers.google.cloud.bundles.gcs.GCSDagBundle**
These bundles reference a GCS bucket containing Dag files. They do not support versioning of the bundle, meaning tasks always run using the latest code.
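For illustration, a bundle entry for this class might look like the following sketch (the bundle name, bucket, and prefix are placeholders, and this assumes the standard ``dag_bundle_config_list`` option covered in the next section)::

    import json
    import os

    # Hypothetical entry; normally set in airflow.cfg under [dag_processor].
    os.environ["AIRFLOW__DAG_PROCESSOR__DAG_BUNDLE_CONFIG_LIST"] = json.dumps(
        [
            {
                "name": "gcs_dags",
                "classpath": "airflow.providers.google.cloud.bundles.gcs.GCSDagBundle",
                "kwargs": {"bucket_name": "my-dags-bucket", "prefix": "dags/"},
            }
        ]
    )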

Configuring Dag bundles
-----------------------

16 changes: 16 additions & 0 deletions providers/google/src/airflow/providers/google/cloud/bundles/__init__.py
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
161 changes: 161 additions & 0 deletions providers/google/src/airflow/providers/google/cloud/bundles/gcs.py
@@ -0,0 +1,161 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import os
from pathlib import Path

import structlog
from google.api_core.exceptions import NotFound

from airflow.dag_processing.bundles.base import BaseDagBundle
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook


class GCSDagBundle(BaseDagBundle):
"""
GCS Dag bundle - exposes a directory in GCS as a Dag bundle.

This allows Airflow to load Dags directly from a GCS bucket.

:param gcp_conn_id: Airflow connection ID for GCS. Defaults to GoogleBaseHook.default_conn_name.
:param bucket_name: The name of the GCS bucket containing the Dag files.
:param prefix: Optional subdirectory within the GCS bucket where the Dags are stored.
If empty, Dags are assumed to be at the root of the bucket.
"""

supports_versioning = False

def __init__(
self,
*,
gcp_conn_id: str = GoogleBaseHook.default_conn_name,
bucket_name: str,
prefix: str = "",
**kwargs,
) -> None:
super().__init__(**kwargs)
self.gcp_conn_id = gcp_conn_id
self.bucket_name = bucket_name
self.prefix = prefix
# Local path where GCS Dags are downloaded
self.gcs_dags_dir: Path = self.base_dir

log = structlog.get_logger(__name__)
self._log = log.bind(
bundle_name=self.name,
version=self.version,
bucket_name=self.bucket_name,
prefix=self.prefix,
gcp_conn_id=self.gcp_conn_id,
)
self._gcs_hook: GCSHook | None = None

def _initialize(self):
with self.lock():
if not self.gcs_dags_dir.exists():
self._log.info("Creating local Dags directory: %s", self.gcs_dags_dir)
os.makedirs(self.gcs_dags_dir)

if not self.gcs_dags_dir.is_dir():
raise NotADirectoryError(f"Local Dags path: {self.gcs_dags_dir} is not a directory.")

try:
self.gcs_hook.get_bucket(bucket_name=self.bucket_name)
except NotFound:
raise ValueError(f"GCS bucket '{self.bucket_name}' does not exist.")

if self.prefix:
# don't check when prefix is ""
if not self.gcs_hook.list(bucket_name=self.bucket_name, prefix=self.prefix):
raise ValueError(f"GCS prefix 'gs://{self.bucket_name}/{self.prefix}' does not exist.")
self.refresh()

def initialize(self) -> None:
self._initialize()
super().initialize()

@property
def gcs_hook(self):
if self._gcs_hook is None:
try:
self._gcs_hook = GCSHook(gcp_conn_id=self.gcp_conn_id)  # Lazily initialize the GCS hook.
except AirflowException as e:
self._log.warning("Could not create GCSHook for connection %s: %s", self.gcp_conn_id, e)
return self._gcs_hook

def __repr__(self):
return (
f"<GCSDagBundle("
f"name={self.name!r}, "
f"bucket_name={self.bucket_name!r}, "
f"prefix={self.prefix!r}, "
f"version={self.version!r}"
f")>"
)

def get_current_version(self) -> str | None:
"""Return the current version of the Dag bundle. Currently not supported."""
return None

@property
def path(self) -> Path:
"""Return the local path to the Dag files."""
return self.gcs_dags_dir # Path where Dags are downloaded.

def refresh(self) -> None:
"""Refresh the Dag bundle by re-downloading the Dags from GCS."""
if self.version:
raise ValueError("Refreshing a specific version is not supported")

with self.lock():
self._log.debug(
"Downloading Dags from gs://%s/%s to %s", self.bucket_name, self.prefix, self.gcs_dags_dir
)
self.gcs_hook.sync_to_local_dir(
bucket_name=self.bucket_name,
prefix=self.prefix,
local_dir=self.gcs_dags_dir,
delete_stale=True,
)

def view_url(self, version: str | None = None) -> str | None:
"""
Return a URL for viewing the Dags in GCS. Currently, versioning is not supported.

This method is deprecated and will be removed when the minimum supported Airflow version is 3.1.
Use `view_url_template` instead.
"""
return self.view_url_template()

def view_url_template(self) -> str | None:
"""Return a URL for viewing the Dags in GCS. Currently, versioning is not supported."""
if self.version:
raise ValueError("GCS url with version is not supported")
if hasattr(self, "_view_url_template") and self._view_url_template:
# Because we use this method in the view_url method, we need to handle
# backward compatibility for Airflow versions that don't have the
# _view_url_template attribute. Should be removed when we drop support for Airflow 3.0.
return self._view_url_template
# https://console.cloud.google.com/storage/browser/<bucket-name>/<prefix>
url = f"https://console.cloud.google.com/storage/browser/{self.bucket_name}"
if self.prefix:
url += f"/{self.prefix}"

return url
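A minimal usage sketch of the bundle above (the bucket, prefix, and bundle name are placeholders; in practice the Dag processor constructs bundles from configuration rather than user code):

    from airflow.providers.google.cloud.bundles.gcs import GCSDagBundle

    # All values here are hypothetical; "name" is forwarded to BaseDagBundle.
    bundle = GCSDagBundle(name="gcs_dags", bucket_name="my-dags-bucket", prefix="dags/")
    bundle.initialize()  # validates the bucket/prefix and performs the first sync
    dag_dir = bundle.path  # local directory the Dag processor parses
    console_url = bundle.view_url_template()  # browser URL for the bucket/prefix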
110 changes: 107 additions & 3 deletions providers/google/src/airflow/providers/google/cloud/hooks/gcs.py
@@ -28,8 +28,10 @@
import warnings
from collections.abc import Callable, Generator, Sequence
from contextlib import contextmanager
from datetime import datetime
from functools import partial
from io import BytesIO
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import IO, TYPE_CHECKING, Any, ParamSpec, TypeVar, cast, overload
from urllib.parse import urlsplit
@@ -50,12 +52,14 @@
GoogleBaseAsyncHook,
GoogleBaseHook,
)

try:
from airflow.sdk import timezone
except ImportError:
from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
from airflow.version import version

if TYPE_CHECKING:
from aiohttp import ClientSession
from google.api_core.retry import Retry
from google.cloud.storage.blob import Blob
@@ -1249,6 +1253,106 @@ def compose(self, bucket_name: str, source_objects: List[str], destination_objec

self.log.info("Completed successfully.")

def _sync_to_local_dir_delete_stale_local_files(self, current_gcs_objects: list[Path], local_dir: Path):
current_gcs_keys = {key.resolve() for key in current_gcs_objects}

for item in local_dir.rglob("*"):
if item.is_file():
if item.resolve() not in current_gcs_keys:
self.log.debug("Deleting stale local file: %s", item)
item.unlink()
# Clean up empty directories
for root, dirs, _ in os.walk(local_dir, topdown=False):
for d in dirs:
dir_path = os.path.join(root, d)
if not os.listdir(dir_path):
self.log.debug("Deleting stale empty directory: %s", dir_path)
os.rmdir(dir_path)

def _sync_to_local_dir_if_changed(self, blob: Blob, local_target_path: Path):
should_download = False
download_msg = ""
if not local_target_path.exists():
should_download = True
download_msg = f"Local file {local_target_path} does not exist."
else:
local_stats = local_target_path.stat()
# Reload blob to get fresh metadata, including size and updated time
blob.reload()

if blob.size != local_stats.st_size:
should_download = True
download_msg = (
f"GCS object size ({blob.size}) and local file size ({local_stats.st_size}) differ."
)

gcs_last_modified = blob.updated
if (
not should_download
and gcs_last_modified
and local_stats.st_mtime < gcs_last_modified.timestamp()
):
should_download = True
download_msg = f"GCS object last modified ({gcs_last_modified}) is newer than local file last modified ({datetime.fromtimestamp(local_stats.st_mtime, tz=timezone.utc)})."

if should_download:
self.log.debug("%s Downloading %s to %s", download_msg, blob.name, local_target_path.as_posix())
self.download(
bucket_name=blob.bucket.name, object_name=blob.name, filename=str(local_target_path)
)
else:
self.log.debug(
"Local file %s is up-to-date with GCS object %s. Skipping download.",
local_target_path.as_posix(),
blob.name,
)

def sync_to_local_dir(
self,
bucket_name: str,
local_dir: str | Path,
prefix: str | None = None,
delete_stale: bool = False,
) -> None:
"""
Download files from a GCS bucket to a local directory.

It will download all files from the given ``prefix`` and create the corresponding
directory structure in the ``local_dir``.

If ``delete_stale`` is ``True``, it will delete all local files that do not exist in the GCS bucket.

:param bucket_name: The name of the GCS bucket.
:param local_dir: The local directory to which the files will be downloaded.
:param prefix: The prefix of the files to be downloaded.
:param delete_stale: If ``True``, deletes local files that don't exist in the bucket.
"""
prefix = prefix or ""
local_dir_path = Path(local_dir)
self.log.debug("Downloading data from gs://%s/%s to %s", bucket_name, prefix, local_dir_path)

gcs_bucket = self.get_bucket(bucket_name)
local_gcs_objects = []

for blob in gcs_bucket.list_blobs(prefix=prefix):
# GCS lists "directories" as objects ending with a slash. We should skip them.
if blob.name.endswith("/"):
continue

blob_path = Path(blob.name)
local_target_path = local_dir_path.joinpath(blob_path.relative_to(prefix))
if not local_target_path.parent.exists():
local_target_path.parent.mkdir(parents=True, exist_ok=True)
self.log.debug("Created local directory: %s", local_target_path.parent)

self._sync_to_local_dir_if_changed(blob=blob, local_target_path=local_target_path)
local_gcs_objects.append(local_target_path)

if delete_stale:
self._sync_to_local_dir_delete_stale_local_files(
current_gcs_objects=local_gcs_objects, local_dir=local_dir_path
)
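
# Usage sketch for the method above (bucket, prefix, and paths are
# hypothetical; credentials are resolved from the given Airflow connection):
#
#     hook = GCSHook(gcp_conn_id="google_cloud_default")
#     hook.sync_to_local_dir(
#         bucket_name="my-dags-bucket",
#         prefix="dags/",
#         local_dir="/tmp/dags",
#         delete_stale=True,  # prune local files no longer present in the bucket
#     )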

def sync(
self,
source_bucket: str,
16 changes: 16 additions & 0 deletions providers/google/tests/unit/google/cloud/bundles/__init__.py
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.