Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ class GoogleAdsToGcsOperator(BaseOperator):
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param api_version: Optional Google Ads API version to use.
:param unwrap_single: If True, return the GCS URI as a string instead of a list.
Defaults to False for backward compatibility.
"""

template_fields: Sequence[str] = (
Expand All @@ -86,6 +88,7 @@ def __init__(
gzip: bool = False,
impersonation_chain: str | Sequence[str] | None = None,
api_version: str | None = None,
unwrap_single: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -99,8 +102,15 @@ def __init__(
self.gzip = gzip
self.impersonation_chain = impersonation_chain
self.api_version = api_version
self.unwrap_single = unwrap_single

def execute(self, context: Context) -> None:
def execute(self, context: Context) -> str | list[str]:
"""
Fetch data from Google Ads API and upload CSV to GCS.

:return: The destination GCS URI of the uploaded CSV file.
Returns a string if unwrap_single is True, otherwise returns a list containing the URI.
"""
service = GoogleAdsHook(
gcp_conn_id=self.gcp_conn_id,
google_ads_conn_id=self.google_ads_conn_id,
Expand Down Expand Up @@ -129,3 +139,8 @@ def execute(self, context: Context) -> None:
gzip=self.gzip,
)
self.log.info("%s uploaded to GCS", self.obj)

gcs_uri = f"gs://{self.bucket}/{self.obj}"
if self.unwrap_single:
return gcs_uri
return [gcs_uri]
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_execute(self, mock_gcs_hook, mock_ads_hook):
impersonation_chain=IMPERSONATION_CHAIN,
api_version=api_version,
)
op.execute({})
result = op.execute({})
mock_ads_hook.assert_called_once_with(
gcp_conn_id=gcp_conn_id,
google_ads_conn_id=google_ads_conn_id,
Expand All @@ -63,3 +63,39 @@ def test_execute(self, mock_gcs_hook, mock_ads_hook):
mock_gcs_hook.return_value.upload.assert_called_once_with(
bucket_name=BUCKET, object_name=GCS_OBJ_PATH, filename=mock.ANY, gzip=False
)
assert result == [f"gs://{BUCKET}/{GCS_OBJ_PATH}"]
assert isinstance(result, list)
assert len(result) == 1

@mock.patch("airflow.providers.google.ads.transfers.ads_to_gcs.GoogleAdsHook")
@mock.patch("airflow.providers.google.ads.transfers.ads_to_gcs.GCSHook")
def test_execute_with_unwrap_single(self, mock_gcs_hook, mock_ads_hook):
op = GoogleAdsToGcsOperator(
gcp_conn_id=gcp_conn_id,
google_ads_conn_id=google_ads_conn_id,
client_ids=CLIENT_IDS,
query=QUERY,
attributes=FIELDS_TO_EXTRACT,
obj=GCS_OBJ_PATH,
bucket=BUCKET,
task_id="run_operator",
impersonation_chain=IMPERSONATION_CHAIN,
api_version=api_version,
unwrap_single=True,
)
result = op.execute({})
mock_ads_hook.assert_called_once_with(
gcp_conn_id=gcp_conn_id,
google_ads_conn_id=google_ads_conn_id,
api_version=api_version,
)
mock_ads_hook.return_value.search.assert_called_once_with(client_ids=CLIENT_IDS, query=QUERY)
mock_gcs_hook.assert_called_once_with(
gcp_conn_id=gcp_conn_id,
impersonation_chain=IMPERSONATION_CHAIN,
)
mock_gcs_hook.return_value.upload.assert_called_once_with(
bucket_name=BUCKET, object_name=GCS_OBJ_PATH, filename=mock.ANY, gzip=False
)
assert result == f"gs://{BUCKET}/{GCS_OBJ_PATH}"
assert isinstance(result, str)