diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_gcs.py b/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_gcs.py
index 0c6d45bb50d69..8391fe040c5d4 100644
--- a/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_gcs.py
+++ b/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_gcs.py
@@ -237,7 +237,7 @@ def __init__(
         self.exact_match = exact_match
         self.match_glob = match_glob
 
-    def execute(self, context: Context):
+    def execute(self, context: Context) -> list[str]:
         hook = GCSHook(
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
@@ -274,13 +274,22 @@ def execute(self, context: Context):
             raise AirflowException("You can't have two empty strings inside source_object")
 
         # Iterate over the source_objects and do the copy
+        destination_uris_result: list[str] = []
         for prefix in self.source_objects:
             # Check if prefix contains wildcard
             if WILDCARD in prefix:
-                self._copy_source_with_wildcard(hook=hook, prefix=prefix)
+                destination_uris_result.extend(self._copy_source_with_wildcard(hook=hook, prefix=prefix))
             # Now search with prefix using provided delimiter if any
             else:
-                self._copy_source_without_wildcard(hook=hook, prefix=prefix)
+                destination_uris_result.extend(self._copy_source_without_wildcard(hook=hook, prefix=prefix))
+
+        # Deduplicate while preserving order. The same destination URI can appear more than once
+        # when several source paths map to one file, e.g. source_objects=["src/foo.png", "src/data/"]
+        # with data/ containing foo.png and destination_object="backup/" yields backup/foo.png twice.
+        destination_uris_result = list(dict.fromkeys(destination_uris_result))
+
+        # Exclude directory-like URIs (ending with /) from the return value; copy behavior is unchanged.
+        return [u for u in destination_uris_result if not u.endswith("/")]
 
     def _ignore_existing_files(self, hook, prefix, **kwargs):
         # list all files in the Destination GCS bucket
@@ -312,10 +321,12 @@ def _ignore_existing_files(self, hook, prefix, **kwargs):
         self.log.info("There are no new files to sync. Have a nice day!")
         return objects
 
-    def _copy_source_without_wildcard(self, hook, prefix):
+    def _copy_source_without_wildcard(self, hook, prefix) -> list[str]:
         """
         List all files in source_objects, copy files to destination_object, and rename each source file.
 
+        :return: List of destination URIs (gs://bucket/object) for files that were copied.
+
         For source_objects with no wildcard, this operator would first list
         all files in source_objects, using provided delimiter if any. Then copy
         files from source_objects to destination_object and rename each source
@@ -399,32 +410,37 @@ def _copy_source_without_wildcard(self, hook, prefix):
             hook, prefix, objects=objects, delimiter=self.delimiter, match_glob=self.match_glob
         )
 
+        result_uris: list[str] = []
         # If objects is empty, and we have prefix, let's check if prefix is a blob
         # and copy directly
         if len(objects) == 0 and prefix:
             if hook.exists(self.source_bucket, prefix):
-                self._copy_single_object(
+                uri = self._copy_single_object(
                     hook=hook, source_object=prefix, destination_object=self.destination_object
                 )
+                if uri:
+                    result_uris.append(uri)
             elif self.source_object_required:
                 msg = f"{prefix} does not exist in bucket {self.source_bucket}"
                 self.log.warning(msg)
                 raise AirflowException(msg)
 
         if len(objects) == 1 and objects[0][-1] != "/":
-            self._copy_file(hook=hook, source_object=objects[0])
+            result_uris.extend(self._copy_file(hook=hook, source_object=objects[0]))
         elif len(objects):
-            self._copy_multiple_objects(hook=hook, source_objects=objects, prefix=prefix)
+            result_uris.extend(self._copy_multiple_objects(hook=hook, source_objects=objects, prefix=prefix))
+        return result_uris
 
-    def _copy_file(self, hook, source_object):
+    def _copy_file(self, hook, source_object) -> list[str]:
         destination_object = self.destination_object or source_object
         if self.destination_object and self.destination_object[-1] == "/":
             file_name = source_object.split("/")[-1]
             destination_object += file_name
-        self._copy_single_object(
+        uri = self._copy_single_object(
             hook=hook, source_object=source_object, destination_object=destination_object
         )
+        return [uri] if uri else []
 
-    def _copy_multiple_objects(self, hook, source_objects, prefix):
+    def _copy_multiple_objects(self, hook, source_objects, prefix) -> list[str]:
         # Check whether the prefix is a root directory for all the rest of objects.
         _pref = prefix.rstrip("/")
         is_directory = prefix.endswith("/") or all(
@@ -436,6 +452,7 @@ def _copy_multiple_objects(self, hook, source_objects, prefix):
         else:
             base_path = prefix[0 : prefix.rfind("/") + 1] if "/" in prefix else ""
 
+        result_uris: list[str] = []
         for source_obj in source_objects:
             if not self._check_exact_match(source_obj, prefix):
                 continue
@@ -445,9 +462,12 @@ def _copy_multiple_objects(self, hook, source_objects, prefix):
             file_name_postfix = source_obj.replace(base_path, "", 1)
             destination_object = self.destination_object.rstrip("/") + "/" + file_name_postfix
 
-            self._copy_single_object(
+            uri = self._copy_single_object(
                 hook=hook, source_object=source_obj, destination_object=destination_object
             )
+            if uri:
+                result_uris.append(uri)
+        return result_uris
 
     def _check_exact_match(self, source_object: str, prefix: str) -> bool:
         """Check whether source_object's name matches the prefix according to the exact_match flag."""
@@ -455,7 +475,7 @@ def _check_exact_match(self, source_object: str, prefix: str) -> bool:
             return False
         return True
 
-    def _copy_source_with_wildcard(self, hook, prefix):
+    def _copy_source_with_wildcard(self, hook, prefix) -> list[str]:
         total_wildcards = prefix.count(WILDCARD)
         if total_wildcards > 1:
             error_msg = (
@@ -480,17 +500,22 @@ def _copy_source_with_wildcard(self, hook, prefix):
        # remove previous line and uncomment the following:
        # objects = self._ignore_existing_files(hook, prefix_, match_glob=match_glob, objects=objects)
 
+        result_uris: list[str] = []
         for source_object in objects:
             if self.destination_object is None:
                 destination_object = source_object
             else:
                 destination_object = source_object.replace(prefix_, self.destination_object, 1)
-            self._copy_single_object(
+            uri = self._copy_single_object(
                 hook=hook, source_object=source_object, destination_object=destination_object
             )
+            if uri:
+                result_uris.append(uri)
+        return result_uris
 
-    def _copy_single_object(self, hook, source_object, destination_object):
+    def _copy_single_object(self, hook, source_object, destination_object) -> str | None:
+        dest_bucket = self.destination_bucket
         if self.is_older_than:
             # Here we check if the given object is older than the given time
             # If given, last_modified_time and maximum_modified_time is ignored
@@ -498,7 +523,7 @@ def _copy_single_object(self, hook, source_object, destination_object):
             if hook.is_older_than(self.source_bucket, source_object, self.is_older_than):
                 self.log.info("Object is older than %s seconds ago", self.is_older_than)
             else:
                 self.log.debug("Object is not older than %s seconds ago", self.is_older_than)
-                return
+                return None
         elif self.last_modified_time and self.maximum_modified_time:
             # check to see if object was modified between last_modified_time and
             # maximum_modified_time
@@ -516,7 +541,7 @@ def _copy_single_object(self, hook, source_object, destination_object):
                     self.last_modified_time,
                     self.maximum_modified_time,
                 )
-                return
+                return None
 
         elif self.last_modified_time is not None:
             # Check to see if object was modified after last_modified_time
@@ -524,27 +549,29 @@ def _copy_single_object(self, hook, source_object, destination_object):
                 self.log.info("Object has been modified after %s ", self.last_modified_time)
             else:
                 self.log.debug("Object was not modified after %s ", self.last_modified_time)
-                return
+                return None
         elif self.maximum_modified_time is not None:
             # Check to see if object was modified before maximum_modified_time
             if hook.is_updated_before(self.source_bucket, source_object, self.maximum_modified_time):
                 self.log.info("Object has been modified before %s ", self.maximum_modified_time)
             else:
                 self.log.debug("Object was not modified before %s ", self.maximum_modified_time)
modified before %s ", self.maximum_modified_time) - return + return None self.log.info( "Executing copy of gs://%s/%s to gs://%s/%s", self.source_bucket, source_object, - self.destination_bucket, + dest_bucket, destination_object, ) - hook.rewrite(self.source_bucket, source_object, self.destination_bucket, destination_object) + hook.rewrite(self.source_bucket, source_object, dest_bucket, destination_object) if self.move_object: hook.delete(self.source_bucket, source_object) + return f"gs://{dest_bucket}/{destination_object}" + def get_openlineage_facets_on_complete(self, task_instance): """ Implement _on_complete because execute method does preprocessing on internals. diff --git a/providers/google/tests/unit/google/cloud/transfers/test_gcs_to_gcs.py b/providers/google/tests/unit/google/cloud/transfers/test_gcs_to_gcs.py index bc347a84e1b32..45dcff4516bfe 100644 --- a/providers/google/tests/unit/google/cloud/transfers/test_gcs_to_gcs.py +++ b/providers/google/tests/unit/google/cloud/transfers/test_gcs_to_gcs.py @@ -1000,3 +1000,65 @@ def test_get_openlineage_facets_on_complete( assert all(element in inputs for element in lineage.inputs) assert all(element in lineage.outputs for element in outputs) assert all(element in outputs for element in lineage.outputs) + + # Return value tests + @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook") + def test_execute_returns_list_of_destination_uris_single_file(self, mock_hook): + mock_hook.return_value.list.return_value = [SOURCE_OBJECT_NO_WILDCARD] + operator = GCSToGCSOperator( + task_id=TASK_ID, + source_bucket=TEST_BUCKET, + source_object=SOURCE_OBJECT_NO_WILDCARD, + destination_bucket=DESTINATION_BUCKET, + destination_object=DESTINATION_OBJECT_PREFIX + "/", + exact_match=True, + ) + result = operator.execute(None) + expected_uri = f"gs://{DESTINATION_BUCKET}/{DESTINATION_OBJECT_PREFIX}/{SOURCE_OBJECT_NO_WILDCARD}" + assert result == [expected_uri] + + @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook") + def test_execute_returns_list_of_destination_uris_multiple_files(self, mock_hook): + mock_hook.return_value.list.return_value = SOURCE_OBJECTS_LIST + with pytest.warns(AirflowProviderDeprecationWarning, match="Usage of wildcard"): + operator = GCSToGCSOperator( + task_id=TASK_ID, + source_bucket=TEST_BUCKET, + source_object=SOURCE_OBJECT_WILDCARD_FILENAME, + destination_bucket=DESTINATION_BUCKET, + destination_object=DESTINATION_OBJECT_PREFIX, + ) + result = operator.execute(None) + expected = [ + f"gs://{DESTINATION_BUCKET}/foo/bar/file1.txt", + f"gs://{DESTINATION_BUCKET}/foo/bar/file2.txt", + f"gs://{DESTINATION_BUCKET}/foo/bar/file3.json", + ] + assert sorted(result) == sorted(expected) + + @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook") + def test_execute_returns_empty_list_when_no_files_copied(self, mock_hook): + mock_hook.return_value.is_updated_after.return_value = False + operator = GCSToGCSOperator( + task_id=TASK_ID, + source_bucket=TEST_BUCKET, + source_object=SOURCE_OBJECT_NO_WILDCARD, + destination_bucket=DESTINATION_BUCKET, + destination_object=SOURCE_OBJECT_NO_WILDCARD, + last_modified_time=MOD_TIME_1, + ) + result = operator.execute(None) + assert result == [] + + @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook") + def test_execute_excludes_directory_uri_from_return(self, mock_hook): + mock_hook.return_value.list.return_value = ["prefix/file.txt", "prefix/"] + operator = GCSToGCSOperator( + task_id=TASK_ID, + 
+            source_bucket=TEST_BUCKET,
+            source_objects=["prefix/"],
+            destination_bucket=DESTINATION_BUCKET,
+            destination_object="backup/",
+        )
+        result = operator.execute(None)
+        assert result == [f"gs://{DESTINATION_BUCKET}/backup/file.txt"]
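
Usage note (editor's sketch, not part of the patch): because execute() now returns the list of copied URIs, that list is pushed to XCom as return_value and can feed a downstream task. The sketch below is a minimal consumer under assumed names: the dag_id gcs_copy_report, the buckets example-src-bucket/example-dst-bucket, and the paths data/ and backup/ are all illustrative; only GCSToGCSOperator and the returned list of gs:// URIs come from the patch.

# Hypothetical consumer DAG; bucket names and paths below are placeholders.
from __future__ import annotations

from datetime import datetime

from airflow.decorators import dag, task
from airflow.providers.google.cloud.transfers.gcs_to_gcs import GCSToGCSOperator


@dag(schedule=None, start_date=datetime(2024, 1, 1), catchup=False)
def gcs_copy_report():
    copy_files = GCSToGCSOperator(
        task_id="copy_files",
        source_bucket="example-src-bucket",
        source_objects=["data/"],
        destination_bucket="example-dst-bucket",
        destination_object="backup/",
    )

    @task
    def report(uris: list[str]) -> None:
        # With this patch, uris is the deduplicated list of gs://bucket/object
        # destinations, with directory-like entries (trailing "/") filtered out.
        print(f"Copied {len(uris)} object(s): {uris}")

    # copy_files.output is the XComArg wrapping the operator's return_value.
    report(copy_files.output)


gcs_copy_report()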