Skip to content

Commit

Permalink
feat(x-goog-spanner-request-id): implement request_id generation and …
Browse files Browse the repository at this point in the history
…propagation

Generates a request_id that is then injected inside metadata
that's sent over to the Cloud Spanner backend.

Officially inject the first set of x-goog-spanner-request-id values into header metadata
Add request-id interceptor to use in asserting tests
Wrap Snapshot methods with x-goog-request-id metadata injector
Setup scaffolding for XGoogRequestIdHeader checks
Wire up XGoogSpannerRequestIdInterceptor for TestDatabase checks
Inject the header at more Session call sites, plus more tests
Base for tests with retries on abort
More plumbing for Transaction and Database
Update unit tests for Transaction
Wrap more in Transaction + update tests
Update tests
Plumb in more tests
Update TestDatabase

Fixes #1261
  • Loading branch information
odeke-em committed Dec 20, 2024
1 parent f2483e1 commit 65757b5
Show file tree
Hide file tree
Showing 17 changed files with 1,168 additions and 128 deletions.
9 changes: 9 additions & 0 deletions google/cloud/spanner_v1/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from google.cloud.spanner_v1._helpers import _merge_query_options
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
from google.cloud.spanner_v1.instance import Instance
from google.cloud.spanner_v1._helpers import AtomicCounter

_CLIENT_INFO = client_info.ClientInfo(client_library_version=__version__)
EMULATOR_ENV_VAR = "SPANNER_EMULATOR_HOST"
Expand Down Expand Up @@ -147,6 +148,8 @@ class Client(ClientWithProject):
SCOPE = (SPANNER_ADMIN_SCOPE,)
"""The scopes required for Google Cloud Spanner."""

NTH_CLIENT = AtomicCounter()

def __init__(
self,
project=None,
Expand Down Expand Up @@ -199,6 +202,12 @@ def __init__(
self._route_to_leader_enabled = route_to_leader_enabled
self._directed_read_options = directed_read_options
self._observability_options = observability_options
self._nth_client_id = Client.NTH_CLIENT.increment()
self._nth_request = AtomicCounter()

@property
def _next_nth_request(self):
    """Advance this client's request counter and return the result.

    Note that reading this property has a side effect: every access calls
    ``increment()`` on the per-client ``_nth_request`` counter.
    """
    request_counter = self._nth_request
    return request_counter.increment()

@property
def credentials(self):
Expand Down
67 changes: 61 additions & 6 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@
from google.cloud.spanner_v1 import SpannerClient
from google.cloud.spanner_v1._helpers import _merge_query_options
from google.cloud.spanner_v1._helpers import (
AtomicCounter,
_metadata_with_prefix,
_metadata_with_leader_aware_routing,
_metadata_with_request_id,
)
from google.cloud.spanner_v1.batch import Batch
from google.cloud.spanner_v1.batch import MutationGroups
Expand Down Expand Up @@ -149,6 +151,9 @@ class Database(object):

_spanner_api: SpannerClient = None

__transport_lock = threading.Lock()
__transports_to_channel_id = dict()

def __init__(
self,
database_id,
Expand Down Expand Up @@ -443,6 +448,31 @@ def spanner_api(self):
)
return self._spanner_api

@property
def _channel_id(self):
    """Return the channel id associated with the current ``spanner_api`` transport.

    The id is one of the components of x-goog-spanner-request-id. The first
    time a transport is seen it is given the next ordinal (current number of
    known transports plus one) and remembered in the transport-to-id
    mapping; later lookups for the same transport return the stored id.
    The lookup and assignment happen while holding the transport lock.
    """
    with self.__transport_lock:
        transport = self.spanner_api._transport
        known_transports = self.__transports_to_channel_id
        assigned_id = known_transports.get(transport)
        if assigned_id is not None:
            return assigned_id

        # First sighting of this transport: hand out the next ordinal.
        assigned_id = len(known_transports) + 1
        known_transports[transport] = assigned_id
        return assigned_id

def metadata_with_request_id(self, nth_request, nth_attempt, prior_metadata=None):
    """Build call metadata carrying this database's request-id components.

    Delegates to ``_metadata_with_request_id``, passing the owning client's
    ordinal, this database's channel id, and the caller-supplied request
    and attempt numbers.

    :param nth_request: ordinal of the request (e.g. a value obtained from
        ``_next_nth_request``).
    :param nth_attempt: attempt number for that request.
    :param prior_metadata: (Optional) existing metadata to hand to the
        helper. When omitted, a fresh empty list is used.
    :returns: whatever ``_metadata_with_request_id`` returns (presumably the
        combined metadata sequence -- confirm against the helper).
    """
    # A ``None`` sentinel replaces the former ``[]`` default so that calls
    # which omit ``prior_metadata`` never share one list object.
    if prior_metadata is None:
        prior_metadata = []

    return _metadata_with_request_id(
        self._nth_client_id,
        self._channel_id,
        nth_request,
        nth_attempt,
        prior_metadata,
    )

def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
Expand Down Expand Up @@ -698,10 +728,20 @@ def execute_partitioned_dml(
_metadata_with_leader_aware_routing(self._route_to_leader_enabled)
)

# Attempt will be incremented inside _restart_on_unavailable.
begin_txn_nth_request = self._next_nth_request
begin_txn_attempt = AtomicCounter(1)
partial_nth_request = self._next_nth_request
partial_attempt = AtomicCounter(0)

def execute_pdml():
with SessionCheckout(self._pool) as session:
txn = api.begin_transaction(
session=session.name, options=txn_options, metadata=metadata
session=session.name,
options=txn_options,
metadata=self.metadata_with_request_id(
begin_txn_nth_request, begin_txn_attempt.value, metadata
),
)

txn_selector = TransactionSelector(id=txn.id)
Expand All @@ -714,17 +754,24 @@ def execute_pdml():
query_options=query_options,
request_options=request_options,
)
method = functools.partial(
api.execute_streaming_sql,
metadata=metadata,
)

def wrapped_method(*args, **kwargs):
partial_attempt.increment()
method = functools.partial(
api.execute_streaming_sql,
metadata=self.metadata_with_request_id(
partial_nth_request, partial_attempt.value, metadata
),
)
return method(*args, **kwargs)

iterator = _restart_on_unavailable(
method=method,
method=wrapped_method,
trace_name="CloudSpanner.ExecuteStreamingSql",
request=request,
transaction_selector=txn_selector,
observability_options=self.observability_options,
attempt=begin_txn_attempt,
)

result_set = StreamedResultSet(iterator)
Expand All @@ -734,6 +781,14 @@ def execute_pdml():

return _retry_on_aborted(execute_pdml, DEFAULT_RETRY_BACKOFF)()

@property
def _next_nth_request(self):
    """Return the next request ordinal from the client that owns this database.

    Simply forwards to ``_next_nth_request`` on ``self._instance._client``.
    """
    owning_client = self._instance._client
    return owning_client._next_nth_request

@property
def _nth_client_id(self):
    """Return the ``_nth_client_id`` of the client that owns this database.

    Simply forwards to ``_nth_client_id`` on ``self._instance._client``.
    """
    owning_client = self._instance._client
    return owning_client._nth_client_id

def session(self, labels=None, database_role=None):
"""Factory to create a session for this database.
Expand Down
1 change: 1 addition & 0 deletions google/cloud/spanner_v1/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ def database(
proto_descriptors=proto_descriptors,
)
else:
print("enabled interceptors")
return TestDatabase(
database_id,
self,
Expand Down
6 changes: 5 additions & 1 deletion google/cloud/spanner_v1/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ def bind(self, database):
"CloudSpanner.FixedPool.BatchCreateSessions",
observability_options=observability_options,
) as span:
attempt = 1
returned_session_count = 0
while not self._sessions.full():
request.session_count = requested_session_count - self._sessions.qsize()
Expand All @@ -251,9 +252,12 @@ def bind(self, database):
f"Creating {request.session_count} sessions",
span_event_attributes,
)
all_metadata = database.metadata_with_request_id(
database._next_nth_request, attempt, metadata
)
resp = api.batch_create_sessions(
request=request,
metadata=metadata,
metadata=all_metadata,
)

add_span_event(
Expand Down
21 changes: 16 additions & 5 deletions google/cloud/spanner_v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ def exists(self):
current_span, "Checking if Session exists", {"session.id": self._session_id}
)

api = self._database.spanner_api
database = self._database
api = database.spanner_api
metadata = _metadata_with_prefix(self._database.name)
if self._database._route_to_leader_enabled:
metadata.append(
Expand All @@ -202,12 +203,16 @@ def exists(self):
)
)

all_metadata = database.metadata_with_request_id(
database._next_nth_request, 1, metadata
)

observability_options = getattr(self._database, "observability_options", None)
with trace_call(
"CloudSpanner.GetSession", self, observability_options=observability_options
) as span:
try:
api.get_session(name=self.name, metadata=metadata)
api.get_session(name=self.name, metadata=all_metadata)
if span:
span.set_attribute("session_found", True)
except NotFound:
Expand Down Expand Up @@ -237,8 +242,11 @@ def delete(self):
current_span, "Deleting Session", {"session.id": self._session_id}
)

api = self._database.spanner_api
metadata = _metadata_with_prefix(self._database.name)
database = self._database
api = database.spanner_api
metadata = database.metadata_with_request_id(
database._next_nth_request, 1, _metadata_with_prefix(database.name)
)
observability_options = getattr(self._database, "observability_options", None)
with trace_call(
"CloudSpanner.DeleteSession",
Expand All @@ -259,7 +267,10 @@ def ping(self):
if self._session_id is None:
raise ValueError("Session ID not set by back-end")
api = self._database.spanner_api
metadata = _metadata_with_prefix(self._database.name)
database = self._database
metadata = database.metadata_with_request_id(
database._next_nth_request, 1, _metadata_with_prefix(database.name)
)
request = ExecuteSqlRequest(session=self.name, sql="SELECT 1")
api.execute_sql(request=request, metadata=metadata)
self._last_use_time = datetime.now()
Expand Down
Loading

0 comments on commit 65757b5

Please sign in to comment.