Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from __future__ import annotations

import os
import posixpath
from collections.abc import Sequence
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -58,6 +58,7 @@ class ADLSToGCSOperator(ADLSListOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:return: List of destination GCS URIs in the format ``gs://bucket/object``

**Examples**:
The following Operator would copy a single file named
Expand Down Expand Up @@ -126,7 +127,7 @@ def __init__(
self.gzip = gzip
self.google_impersonation_chain = google_impersonation_chain

def execute(self, context: Context) -> list[str]:
    """Copy files from Azure Data Lake Storage to Google Cloud Storage.

    Lists files via the parent ``ADLSListOperator``, optionally skips objects
    that already exist under the destination prefix (when ``self.replace`` is
    false), then downloads each remaining file to a local temporary file and
    uploads it to the destination GCS bucket.

    :param context: Airflow task execution context.
    :return: List of destination GCS URIs in the format ``gs://bucket/object``,
        one per uploaded file (empty list when nothing needed uploading).
    """
    # use the super to list all files in an Azure Data Lake path
    files = super().execute(context)
    # NOTE(review): the next few lines were collapsed in the diff view this was
    # recovered from — verify the GCSHook kwargs and the replace-guard against
    # the full file.
    g_hook = GCSHook(
        gcp_conn_id=self.gcp_conn_id,
        impersonation_chain=self.google_impersonation_chain,
    )

    if not self.replace:
        # If we are not replacing, list all files already in the GCS bucket
        # and only keep those present in ADLS but missing from GCS.
        bucket_name, prefix = _parse_gcs_url(self.dest_gcs)
        existing_files = g_hook.list(bucket_name=bucket_name, prefix=prefix)
        files = list(set(files) - set(existing_files))

    destination_uris = []
    if files:
        hook = AzureDataLakeHook(azure_data_lake_conn_id=self.azure_data_lake_conn_id)
        # Parse the destination URL once, outside the per-file loop.
        dest_gcs_bucket, dest_gcs_prefix = _parse_gcs_url(self.dest_gcs)

        for obj in files:
            with NamedTemporaryFile(mode="wb", delete=True) as f:
                hook.download_file(local_path=f.name, remote_path=obj)
                f.flush()
                # posixpath (not os.path) keeps forward slashes regardless of
                # the host OS, which GCS object names require.
                dest_path = posixpath.join(dest_gcs_prefix, obj)
                self.log.info("Saving file to %s", dest_path)

                g_hook.upload(
                    bucket_name=dest_gcs_bucket, object_name=dest_path, filename=f.name, gzip=self.gzip
                )

                # Build and store the destination URI
                destination_uris.append(f"gs://{dest_gcs_bucket}/{dest_path}")

        self.log.info("All done, uploaded %d files to GCS", len(files))
    else:
        self.log.info("In sync, no files needed to be uploaded to GCS")

    return destination_uris
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def test_execute(self, gcs_mock_hook, adls_one_mock_hook, adls_two_mock_hook):
adls_one_mock_hook.return_value.list.return_value = MOCK_FILES
adls_two_mock_hook.return_value.list.return_value = MOCK_FILES

# gcs_mock_hook.return_value.upload.side_effect = _assert_upload
uploaded_files = operator.execute(None)
gcs_mock_hook.return_value.upload.assert_has_calls(
[
Expand All @@ -99,8 +98,13 @@ def test_execute(self, gcs_mock_hook, adls_one_mock_hook, adls_two_mock_hook):
impersonation_chain=IMPERSONATION_CHAIN,
)

# we expect MOCK_FILES to be uploaded
assert sorted(MOCK_FILES) == sorted(uploaded_files)
# Verify that the return value is a list of destination GCS URIs
assert isinstance(uploaded_files, list)
assert len(uploaded_files) == len(MOCK_FILES)

# Verify the returned URIs match the uploaded paths
expected_uris = sorted([f"gs://test/{f}" for f in MOCK_FILES])
assert sorted(uploaded_files) == expected_uris

@mock.patch("airflow.providers.google.cloud.transfers.adls_to_gcs.AzureDataLakeHook")
@mock.patch("airflow.providers.microsoft.azure.operators.adls.AzureDataLakeHook")
Expand All @@ -121,7 +125,6 @@ def test_execute_with_gzip(self, gcs_mock_hook, adls_one_mock_hook, adls_two_moc
adls_one_mock_hook.return_value.list.return_value = MOCK_FILES
adls_two_mock_hook.return_value.list.return_value = MOCK_FILES

# gcs_mock_hook.return_value.upload.side_effect = _assert_upload
uploaded_files = operator.execute(None)
gcs_mock_hook.return_value.upload.assert_has_calls(
[
Expand All @@ -138,5 +141,10 @@ def test_execute_with_gzip(self, gcs_mock_hook, adls_one_mock_hook, adls_two_moc
any_order=True,
)

# we expect MOCK_FILES to be uploaded
assert sorted(MOCK_FILES) == sorted(uploaded_files)
# Verify that the return value is a list of destination GCS URIs
assert isinstance(uploaded_files, list)
assert len(uploaded_files) == len(MOCK_FILES)

# Verify the returned URIs match the uploaded paths
expected_uris = sorted([f"gs://test/{f}" for f in MOCK_FILES])
assert sorted(uploaded_files) == expected_uris