Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from collections.abc import Sequence
from typing import Any

from google.api_core.client_options import ClientOptions
from googleapiclient.discovery import build

from airflow.exceptions import AirflowException
Expand All @@ -44,20 +45,24 @@ class GSheetsHook(GoogleBaseHook):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account.
:param api_endpoint: Optional. Custom API endpoint, i.e: regional or private endpoint.
This can be used to target private VPC or restricted access endpoints.
"""

def __init__(
self,
gcp_conn_id: str = "google_cloud_default",
api_version: str = "v4",
impersonation_chain: str | Sequence[str] | None = None,
api_endpoint: str | None = None,
) -> None:
super().__init__(
gcp_conn_id=gcp_conn_id,
impersonation_chain=impersonation_chain,
)
self.gcp_conn_id = gcp_conn_id
self.api_version = api_version
self.api_endpoint = api_endpoint
self._conn = None

def get_conn(self) -> Any:
Expand All @@ -68,7 +73,16 @@ def get_conn(self) -> Any:
"""
if not self._conn:
http_authorized = self._authorize()
self._conn = build("sheets", self.api_version, http=http_authorized, cache_discovery=False)
client_options = None
if self.api_endpoint:
client_options = ClientOptions(api_endpoint=self.api_endpoint)
self._conn = build(
"sheets",
self.api_version,
http=http_authorized,
cache_discovery=False,
client_options=client_options,
)

return self._conn

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class GoogleSheetsCreateSpreadsheetOperator(BaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param api_endpoint: Optional. Custom API endpoint, e.g: private.googleapis.com.
This can be used to target private VPC or restricted access endpoints.
"""

template_fields: Sequence[str] = (
Expand All @@ -55,17 +57,20 @@ def __init__(
spreadsheet: dict[str, Any],
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
api_endpoint: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.gcp_conn_id = gcp_conn_id
self.spreadsheet = spreadsheet
self.impersonation_chain = impersonation_chain
self.api_endpoint = api_endpoint

def execute(self, context: Any) -> dict[str, Any]:
hook = GSheetsHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
api_endpoint=self.api_endpoint,
)
spreadsheet = hook.create_spreadsheet(spreadsheet=self.spreadsheet)
context["task_instance"].xcom_push(key="spreadsheet_id", value=spreadsheet["spreadsheetId"])
Expand Down
22 changes: 19 additions & 3 deletions providers/google/tests/unit/google/suite/hooks/test_sheets.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,27 @@ def setup_method(self):
@mock.patch("airflow.providers.google.suite.hooks.sheets.build")
def test_gsheets_client_creation(self, mock_build, mock_authorize):
result = self.hook.get_conn()
mock_build.assert_called_once_with(
"sheets", "v4", http=mock_authorize.return_value, cache_discovery=False
)
mock_build.assert_called_once()
args, kwargs = mock_build.call_args
assert kwargs["http"] == mock_authorize.return_value
assert kwargs["cache_discovery"] is False
assert mock_build.return_value == result

@mock.patch("airflow.providers.google.suite.hooks.sheets.GSheetsHook._authorize")
@mock.patch("airflow.providers.google.suite.hooks.sheets.build")
def test_gsheets_hook_custom_endpoint(self, mock_build, mock_authorize):
self.hook.api_endpoint = "https://private.googleapis.com"
self.hook.get_conn()
mock_build.assert_called_once()
_, kwargs = mock_build.call_args
client_options = kwargs.get("client_options")
if client_options is None:
api_endpoint = None
else:
api_endpoint = getattr(client_options, "api_endpoint", None)

assert api_endpoint == "https://private.googleapis.com"

@mock.patch("airflow.providers.google.suite.hooks.sheets.GSheetsHook.get_conn")
def test_get_values(self, get_conn):
get_method = get_conn.return_value.spreadsheets.return_value.values.return_value.get
Expand Down
Loading