diff --git a/providers/google/src/airflow/providers/google/ads/transfers/ads_to_gcs.py b/providers/google/src/airflow/providers/google/ads/transfers/ads_to_gcs.py index 8f7fbd26cf151..041f13ec7c079 100644 --- a/providers/google/src/airflow/providers/google/ads/transfers/ads_to_gcs.py +++ b/providers/google/src/airflow/providers/google/ads/transfers/ads_to_gcs.py @@ -62,6 +62,8 @@ class GoogleAdsToGcsOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). :param api_version: Optional Google Ads API version to use. + :param unwrap_single: If True, return the GCS URI as a string instead of a list. + Defaults to False for backward compatibility. """ template_fields: Sequence[str] = ( @@ -86,6 +88,7 @@ def __init__( gzip: bool = False, impersonation_chain: str | Sequence[str] | None = None, api_version: str | None = None, + unwrap_single: bool = False, **kwargs, ) -> None: super().__init__(**kwargs) @@ -99,8 +102,15 @@ def __init__( self.gzip = gzip self.impersonation_chain = impersonation_chain self.api_version = api_version + self.unwrap_single = unwrap_single - def execute(self, context: Context) -> None: + def execute(self, context: Context) -> str | list[str]: + """ + Fetch data from Google Ads API and upload CSV to GCS. + + :return: The destination GCS URI of the uploaded CSV file. + Returns a string if unwrap_single is True, otherwise returns a list containing the URI. + """ service = GoogleAdsHook( gcp_conn_id=self.gcp_conn_id, google_ads_conn_id=self.google_ads_conn_id, @@ -129,3 +139,8 @@ def execute(self, context: Context) -> None: gzip=self.gzip, ) self.log.info("%s uploaded to GCS", self.obj) + + gcs_uri = f"gs://{self.bucket}/{self.obj}" + if self.unwrap_single: + return gcs_uri + return [gcs_uri] diff --git a/providers/google/tests/unit/google/ads/transfers/test_ads_to_gcs.py b/providers/google/tests/unit/google/ads/transfers/test_ads_to_gcs.py index a5d83ccb308f2..81505e521b913 100644 --- a/providers/google/tests/unit/google/ads/transfers/test_ads_to_gcs.py +++ b/providers/google/tests/unit/google/ads/transfers/test_ads_to_gcs.py @@ -49,7 +49,7 @@ def test_execute(self, mock_gcs_hook, mock_ads_hook): impersonation_chain=IMPERSONATION_CHAIN, api_version=api_version, ) - op.execute({}) + result = op.execute({}) mock_ads_hook.assert_called_once_with( gcp_conn_id=gcp_conn_id, google_ads_conn_id=google_ads_conn_id, @@ -63,3 +63,39 @@ def test_execute(self, mock_gcs_hook, mock_ads_hook): mock_gcs_hook.return_value.upload.assert_called_once_with( bucket_name=BUCKET, object_name=GCS_OBJ_PATH, filename=mock.ANY, gzip=False ) + assert result == [f"gs://{BUCKET}/{GCS_OBJ_PATH}"] + assert isinstance(result, list) + assert len(result) == 1 + + @mock.patch("airflow.providers.google.ads.transfers.ads_to_gcs.GoogleAdsHook") + @mock.patch("airflow.providers.google.ads.transfers.ads_to_gcs.GCSHook") + def test_execute_with_unwrap_single(self, mock_gcs_hook, mock_ads_hook): + op = GoogleAdsToGcsOperator( + gcp_conn_id=gcp_conn_id, + google_ads_conn_id=google_ads_conn_id, + client_ids=CLIENT_IDS, + query=QUERY, + attributes=FIELDS_TO_EXTRACT, + obj=GCS_OBJ_PATH, + bucket=BUCKET, + task_id="run_operator", + impersonation_chain=IMPERSONATION_CHAIN, + api_version=api_version, + unwrap_single=True, + ) + result = op.execute({}) + mock_ads_hook.assert_called_once_with( + gcp_conn_id=gcp_conn_id, + google_ads_conn_id=google_ads_conn_id, + api_version=api_version, + ) + mock_ads_hook.return_value.search.assert_called_once_with(client_ids=CLIENT_IDS, query=QUERY) + mock_gcs_hook.assert_called_once_with( + gcp_conn_id=gcp_conn_id, + impersonation_chain=IMPERSONATION_CHAIN, + ) + mock_gcs_hook.return_value.upload.assert_called_once_with( + bucket_name=BUCKET, object_name=GCS_OBJ_PATH, filename=mock.ANY, gzip=False + ) + assert result == f"gs://{BUCKET}/{GCS_OBJ_PATH}" + assert isinstance(result, str)