Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain

def execute(self, context: Context):
def execute(self, context: Context) -> list[str]:
gdrive_hook = GoogleDriveHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
Expand All @@ -100,6 +100,11 @@ def execute(self, context: Context):
) as file:
gdrive_hook.download_file(file_id=file_metadata["id"], file_handle=file)

gcs_uri = f"gs://{self.bucket_name}/{self.object_name}"
result = [gcs_uri]

return result

def dry_run(self):
"""Perform a dry run of the operator."""
return None
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ class GoogleSheetsToGCSOperator(BaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param return_gcs_uris: If True, returns full GCS URIs (e.g., ``gs://bucket/path/file.csv``).
If False (default), returns object names only (e.g., ``path/to/file.csv``).
Default will change to True in a future release.
"""

template_fields: Sequence[str] = (
Expand All @@ -72,6 +75,7 @@ def __init__(
destination_path: str | None = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
return_gcs_uris: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -81,6 +85,16 @@ def __init__(
self.destination_bucket = destination_bucket
self.destination_path = destination_path
self.impersonation_chain = impersonation_chain
self.return_gcs_uris = return_gcs_uris
if not self.return_gcs_uris:
import warnings

warnings.warn(
"The default value of return_gcs_uris will change from False to True in a future release. "
"Please set return_gcs_uris explicitly to avoid this warning.",
FutureWarning,
stacklevel=2,
)

def _upload_data(
self,
Expand Down Expand Up @@ -110,7 +124,7 @@ def _upload_data(
)
return dest_file_name

def execute(self, context: Context):
def execute(self, context: Context) -> list[str]:
sheet_hook = GSheetsHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
Expand All @@ -128,7 +142,11 @@ def execute(self, context: Context):
for sheet_range in sheet_titles:
data = sheet_hook.get_values(spreadsheet_id=self.spreadsheet_id, range_=sheet_range)
gcs_path_to_file = self._upload_data(gcs_hook, sheet_hook, sheet_range, data)
destination_array.append(gcs_path_to_file)
if self.return_gcs_uris:
gcs_uri = f"gs://{self.destination_bucket}/{gcs_path_to_file}"
destination_array.append(gcs_uri)
else:
destination_array.append(gcs_path_to_file)

context["ti"].xcom_push(key="destination_objects", value=destination_array)
return destination_array
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_execute(self, mock_gdrive_hook, mock_gcs_hook):
meta = {"id": "123xyz"}
mock_gdrive_hook.return_value.get_file_id.return_value = meta

op.execute(context)
result = op.execute(context)
mock_gdrive_hook.return_value.get_file_id.assert_called_once_with(
folder_id=FOLDER_ID, file_name=FILE_NAME, drive_id=DRIVE_ID
)
Expand All @@ -61,4 +61,6 @@ def test_execute(self, mock_gdrive_hook, mock_gcs_hook):
bucket_name=BUCKET, object_name=OBJECT
)

# Assert list with GCS URI is returned
assert result == [f"gs://{BUCKET}/{OBJECT}"]
assert op.dry_run() is None
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,12 @@ def test_upload_data(self, mock_tempfile, mock_writer):
@mock.patch(
"airflow.providers.google.cloud.transfers.sheets_to_gcs.GoogleSheetsToGCSOperator._upload_data"
)
def test_execute(self, mock_upload_data, mock_sheet_hook, mock_gcs_hook):
def test_execute_with_return_gcs_uris_true(
self,
mock_upload_data,
mock_sheet_hook,
mock_gcs_hook,
):
mock_ti = mock.MagicMock()
mock_context = {"ti": mock_ti}
data = ["data1", "data2"]
Expand All @@ -97,8 +102,9 @@ def test_execute(self, mock_upload_data, mock_sheet_hook, mock_gcs_hook):
destination_path=PATH,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
return_gcs_uris=True,
)
op.execute(mock_context)
result = op.execute(mock_context)

mock_sheet_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
Expand All @@ -124,4 +130,67 @@ def test_execute(self, mock_upload_data, mock_sheet_hook, mock_gcs_hook):
actual_call_count = mock_upload_data.call_count
assert len(RANGES) == actual_call_count

expected_uris = [f"gs://{BUCKET}/{PATH}", f"gs://{BUCKET}/{PATH}"]
mock_ti.xcom_push.assert_called_once_with(key="destination_objects", value=expected_uris)
assert result == expected_uris

@mock.patch("airflow.providers.google.cloud.transfers.sheets_to_gcs.GCSHook")
@mock.patch("airflow.providers.google.cloud.transfers.sheets_to_gcs.GSheetsHook")
@mock.patch(
"airflow.providers.google.cloud.transfers.sheets_to_gcs.GoogleSheetsToGCSOperator._upload_data"
)
def test_execute_with_return_gcs_uris_false(
self,
mock_upload_data,
mock_sheet_hook,
mock_gcs_hook,
):
mock_ti = mock.MagicMock()
mock_context = {"ti": mock_ti}
data = ["data1", "data2"]
mock_sheet_hook.return_value.get_sheet_titles.return_value = RANGES
mock_sheet_hook.return_value.get_values.side_effect = data
mock_upload_data.side_effect = [PATH, PATH]
op = GoogleSheetsToGCSOperator(
task_id="test_task",
spreadsheet_id=SPREADSHEET_ID,
destination_bucket=BUCKET,
sheet_filter=FILTER,
destination_path=PATH,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
return_gcs_uris=False,
)
result = op.execute(mock_context)
mock_ti.xcom_push.assert_called_once_with(key="destination_objects", value=[PATH, PATH])
assert result == [PATH, PATH]

@mock.patch("airflow.providers.google.cloud.transfers.sheets_to_gcs.GCSHook")
@mock.patch("airflow.providers.google.cloud.transfers.sheets_to_gcs.GSheetsHook")
@mock.patch(
"airflow.providers.google.cloud.transfers.sheets_to_gcs.GoogleSheetsToGCSOperator._upload_data"
)
def test_execute_with_return_gcs_uris_default(
self,
mock_upload_data,
mock_sheet_hook,
mock_gcs_hook,
):
mock_ti = mock.MagicMock()
mock_context = {"ti": mock_ti}
data = ["data1"]
mock_sheet_hook.return_value.get_sheet_titles.return_value = ["single_range"]
mock_sheet_hook.return_value.get_values.side_effect = data
mock_upload_data.side_effect = [PATH]
op = GoogleSheetsToGCSOperator(
task_id="test_task",
spreadsheet_id=SPREADSHEET_ID,
destination_bucket=BUCKET,
sheet_filter=FILTER,
destination_path=PATH,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
result = op.execute(mock_context)
mock_ti.xcom_push.assert_called_once_with(key="destination_objects", value=[PATH])
assert result == [PATH]