Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def __init__(
self.exact_match = exact_match
self.match_glob = match_glob

def execute(self, context: Context):
def execute(self, context: Context) -> list[str]:
hook = GCSHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
Expand Down Expand Up @@ -274,13 +274,22 @@ def execute(self, context: Context):
raise AirflowException("You can't have two empty strings inside source_object")

# Iterate over the source_objects and do the copy
destination_uris_result: list[str] = []
for prefix in self.source_objects:
# Check if prefix contains wildcard
if WILDCARD in prefix:
self._copy_source_with_wildcard(hook=hook, prefix=prefix)
destination_uris_result.extend(self._copy_source_with_wildcard(hook=hook, prefix=prefix))
# Now search with prefix using provided delimiter if any
else:
self._copy_source_without_wildcard(hook=hook, prefix=prefix)
destination_uris_result.extend(self._copy_source_without_wildcard(hook=hook, prefix=prefix))

# Deduplicate while preserving order. Same destination URI can appear when multiple
# source paths map to one file, e.g. source_objects=["src/foo.png", "src/data/"] with
# data/ containing foo.png yields backup/list/foo.png twice.
destination_uris_result = list(dict.fromkeys(destination_uris_result))

# Exclude directory-like URIs (ending with /) from return; copy behavior is unchanged.
return [u for u in destination_uris_result if not u.endswith("/")]

def _ignore_existing_files(self, hook, prefix, **kwargs):
# list all files in the Destination GCS bucket
Expand Down Expand Up @@ -312,10 +321,12 @@ def _ignore_existing_files(self, hook, prefix, **kwargs):
self.log.info("There are no new files to sync. Have a nice day!")
return objects

def _copy_source_without_wildcard(self, hook, prefix):
def _copy_source_without_wildcard(self, hook, prefix) -> list[str]:
"""
List all files in source_objects, copy files to destination_object, and rename each source file.

:return: List of destination URIs (gs://bucket/object) for files that were copied.

For source_objects with no wildcard, this operator would first list
all files in source_objects, using provided delimiter if any. Then copy
files from source_objects to destination_object and rename each source
Expand Down Expand Up @@ -399,32 +410,37 @@ def _copy_source_without_wildcard(self, hook, prefix):
hook, prefix, objects=objects, delimiter=self.delimiter, match_glob=self.match_glob
)

result_uris: list[str] = []
# If objects is empty, and we have prefix, let's check if prefix is a blob
# and copy directly
if len(objects) == 0 and prefix:
if hook.exists(self.source_bucket, prefix):
self._copy_single_object(
uri = self._copy_single_object(
hook=hook, source_object=prefix, destination_object=self.destination_object
)
if uri:
result_uris.append(uri)
elif self.source_object_required:
msg = f"{prefix} does not exist in bucket {self.source_bucket}"
self.log.warning(msg)
raise AirflowException(msg)
if len(objects) == 1 and objects[0][-1] != "/":
self._copy_file(hook=hook, source_object=objects[0])
result_uris.extend(self._copy_file(hook=hook, source_object=objects[0]))
elif len(objects):
self._copy_multiple_objects(hook=hook, source_objects=objects, prefix=prefix)
result_uris.extend(self._copy_multiple_objects(hook=hook, source_objects=objects, prefix=prefix))
return result_uris

def _copy_file(self, hook, source_object):
def _copy_file(self, hook, source_object) -> list[str]:
destination_object = self.destination_object or source_object
if self.destination_object and self.destination_object[-1] == "/":
file_name = source_object.split("/")[-1]
destination_object += file_name
self._copy_single_object(
uri = self._copy_single_object(
hook=hook, source_object=source_object, destination_object=destination_object
)
return [uri] if uri else []

def _copy_multiple_objects(self, hook, source_objects, prefix):
def _copy_multiple_objects(self, hook, source_objects, prefix) -> list[str]:
# Check whether the prefix is a root directory for all the rest of objects.
_pref = prefix.rstrip("/")
is_directory = prefix.endswith("/") or all(
Expand All @@ -436,6 +452,7 @@ def _copy_multiple_objects(self, hook, source_objects, prefix):
else:
base_path = prefix[0 : prefix.rfind("/") + 1] if "/" in prefix else ""

result_uris: list[str] = []
for source_obj in source_objects:
if not self._check_exact_match(source_obj, prefix):
continue
Expand All @@ -445,17 +462,20 @@ def _copy_multiple_objects(self, hook, source_objects, prefix):
file_name_postfix = source_obj.replace(base_path, "", 1)
destination_object = self.destination_object.rstrip("/") + "/" + file_name_postfix

self._copy_single_object(
uri = self._copy_single_object(
hook=hook, source_object=source_obj, destination_object=destination_object
)
if uri:
result_uris.append(uri)
return result_uris

def _check_exact_match(self, source_object: str, prefix: str) -> bool:
"""Check whether source_object's name matches the prefix according to the exact_match flag."""
if self.exact_match and (source_object != prefix or not source_object.endswith(prefix)):
return False
return True

def _copy_source_with_wildcard(self, hook, prefix):
def _copy_source_with_wildcard(self, hook, prefix) -> list[str]:
total_wildcards = prefix.count(WILDCARD)
if total_wildcards > 1:
error_msg = (
Expand All @@ -480,25 +500,30 @@ def _copy_source_with_wildcard(self, hook, prefix):
# remove previous line and uncomment the following:
# objects = self._ignore_existing_files(hook, prefix_, match_glob=match_glob, objects=objects)

result_uris: list[str] = []
for source_object in objects:
if self.destination_object is None:
destination_object = source_object
else:
destination_object = source_object.replace(prefix_, self.destination_object, 1)

self._copy_single_object(
uri = self._copy_single_object(
hook=hook, source_object=source_object, destination_object=destination_object
)
if uri:
result_uris.append(uri)
return result_uris

def _copy_single_object(self, hook, source_object, destination_object):
def _copy_single_object(self, hook, source_object, destination_object) -> str | None:
dest_bucket = self.destination_bucket
if self.is_older_than:
# Here we check if the given object is older than the given time
# If given, last_modified_time and maximum_modified_time is ignored
if hook.is_older_than(self.source_bucket, source_object, self.is_older_than):
self.log.info("Object is older than %s seconds ago", self.is_older_than)
else:
self.log.debug("Object is not older than %s seconds ago", self.is_older_than)
return
return None
elif self.last_modified_time and self.maximum_modified_time:
# check to see if object was modified between last_modified_time and
# maximum_modified_time
Expand All @@ -516,35 +541,37 @@ def _copy_single_object(self, hook, source_object, destination_object):
self.last_modified_time,
self.maximum_modified_time,
)
return
return None

elif self.last_modified_time is not None:
# Check to see if object was modified after last_modified_time
if hook.is_updated_after(self.source_bucket, source_object, self.last_modified_time):
self.log.info("Object has been modified after %s ", self.last_modified_time)
else:
self.log.debug("Object was not modified after %s ", self.last_modified_time)
return
return None
elif self.maximum_modified_time is not None:
# Check to see if object was modified before maximum_modified_time
if hook.is_updated_before(self.source_bucket, source_object, self.maximum_modified_time):
self.log.info("Object has been modified before %s ", self.maximum_modified_time)
else:
self.log.debug("Object was not modified before %s ", self.maximum_modified_time)
return
return None

self.log.info(
"Executing copy of gs://%s/%s to gs://%s/%s",
self.source_bucket,
source_object,
self.destination_bucket,
dest_bucket,
destination_object,
)
hook.rewrite(self.source_bucket, source_object, self.destination_bucket, destination_object)
hook.rewrite(self.source_bucket, source_object, dest_bucket, destination_object)

if self.move_object:
hook.delete(self.source_bucket, source_object)

return f"gs://{dest_bucket}/{destination_object}"

def get_openlineage_facets_on_complete(self, task_instance):
"""
Implement _on_complete because execute method does preprocessing on internals.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1000,3 +1000,65 @@ def test_get_openlineage_facets_on_complete(
assert all(element in inputs for element in lineage.inputs)
assert all(element in lineage.outputs for element in outputs)
assert all(element in outputs for element in lineage.outputs)

# Return value tests
@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook")
def test_execute_returns_list_of_destination_uris_single_file(self, mock_hook):
mock_hook.return_value.list.return_value = [SOURCE_OBJECT_NO_WILDCARD]
operator = GCSToGCSOperator(
task_id=TASK_ID,
source_bucket=TEST_BUCKET,
source_object=SOURCE_OBJECT_NO_WILDCARD,
destination_bucket=DESTINATION_BUCKET,
destination_object=DESTINATION_OBJECT_PREFIX + "/",
exact_match=True,
)
result = operator.execute(None)
expected_uri = f"gs://{DESTINATION_BUCKET}/{DESTINATION_OBJECT_PREFIX}/{SOURCE_OBJECT_NO_WILDCARD}"
assert result == [expected_uri]

@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook")
def test_execute_returns_list_of_destination_uris_multiple_files(self, mock_hook):
mock_hook.return_value.list.return_value = SOURCE_OBJECTS_LIST
with pytest.warns(AirflowProviderDeprecationWarning, match="Usage of wildcard"):
operator = GCSToGCSOperator(
task_id=TASK_ID,
source_bucket=TEST_BUCKET,
source_object=SOURCE_OBJECT_WILDCARD_FILENAME,
destination_bucket=DESTINATION_BUCKET,
destination_object=DESTINATION_OBJECT_PREFIX,
)
result = operator.execute(None)
expected = [
f"gs://{DESTINATION_BUCKET}/foo/bar/file1.txt",
f"gs://{DESTINATION_BUCKET}/foo/bar/file2.txt",
f"gs://{DESTINATION_BUCKET}/foo/bar/file3.json",
]
assert sorted(result) == sorted(expected)

@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook")
def test_execute_returns_empty_list_when_no_files_copied(self, mock_hook):
mock_hook.return_value.is_updated_after.return_value = False
operator = GCSToGCSOperator(
task_id=TASK_ID,
source_bucket=TEST_BUCKET,
source_object=SOURCE_OBJECT_NO_WILDCARD,
destination_bucket=DESTINATION_BUCKET,
destination_object=SOURCE_OBJECT_NO_WILDCARD,
last_modified_time=MOD_TIME_1,
)
result = operator.execute(None)
assert result == []

@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook")
def test_execute_excludes_directory_uri_from_return(self, mock_hook):
mock_hook.return_value.list.return_value = ["prefix/file.txt", "prefix/"]
operator = GCSToGCSOperator(
task_id=TASK_ID,
source_bucket=TEST_BUCKET,
source_objects=["prefix/"],
destination_bucket=DESTINATION_BUCKET,
destination_object="backup/",
)
result = operator.execute(None)
assert result == [f"gs://{DESTINATION_BUCKET}/backup/file.txt"]