Skip to content

Commit 3dba0fa

Browse files
committed
Revert changes that had crept in from other PR
1 parent 4b6e5c0 commit 3dba0fa

File tree

3 files changed

+50
-87
lines changed

3 files changed

+50
-87
lines changed

google/cloud/spanner_v1/database.py

Lines changed: 14 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,6 @@
5353
_metadata_with_prefix,
5454
_metadata_with_leader_aware_routing,
5555
)
56-
from google.cloud.spanner_v1._opentelemetry_tracing import (
57-
get_current_span,
58-
trace_call,
59-
)
6056
from google.cloud.spanner_v1.batch import Batch
6157
from google.cloud.spanner_v1.batch import MutationGroups
6258
from google.cloud.spanner_v1.keyset import KeySet
@@ -71,6 +67,7 @@
7167
SpannerGrpcTransport,
7268
)
7369
from google.cloud.spanner_v1.table import Table
70+
from google.cloud.spanner_v1._opentelemetry_tracing import get_current_span
7471

7572

7673
SPANNER_DATA_SCOPE = "https://www.googleapis.com/auth/spanner.data"
@@ -698,7 +695,7 @@ def execute_partitioned_dml(
698695
)
699696

700697
def execute_pdml():
701-
def do_execute_pdml(session):
698+
with SessionCheckout(self._pool) as session:
702699
txn = api.begin_transaction(
703700
session=session.name, options=txn_options, metadata=metadata
704701
)
@@ -730,15 +727,6 @@ def do_execute_pdml(session):
730727

731728
return result_set.stats.row_count_lower_bound
732729

733-
with SessionCheckout(self._pool) as session:
734-
observability_options = getattr(self, "observability_options", None)
735-
with trace_call(
736-
"CloudSpanner.execute_pdml",
737-
session,
738-
observability_options=observability_options,
739-
):
740-
return do_execute_pdml(session)
741-
742730
return _retry_on_aborted(execute_pdml, DEFAULT_RETRY_BACKOFF)()
743731

744732
def session(self, labels=None, database_role=None):
@@ -893,26 +881,17 @@ def run_in_transaction(self, func, *args, **kw):
893881
# Sanity check: Is there a transaction already running?
894882
# If there is, then raise a red flag. Otherwise, mark that this one
895883
# is running.
896-
with SessionCheckout(self._pool) as session:
897-
observability_options = getattr(self, "observability_options", None)
898-
with trace_call(
899-
"CloudSpanner.Database.run_in_transaction",
900-
session,
901-
observability_options=observability_options,
902-
):
903-
# Sanity check: Is there a transaction already running?
904-
# If there is, then raise a red flag. Otherwise, mark that this one
905-
# is running.
906-
if getattr(self._local, "transaction_running", False):
907-
raise RuntimeError("Spanner does not support nested transactions.")
908-
self._local.transaction_running = True
909-
910-
# Check out a session and run the function in a transaction; once
911-
# done, flip the sanity check bit back.
912-
try:
913-
return session.run_in_transaction(func, *args, **kw)
914-
finally:
915-
self._local.transaction_running = False
884+
if getattr(self._local, "transaction_running", False):
885+
raise RuntimeError("Spanner does not support nested transactions.")
886+
self._local.transaction_running = True
887+
888+
# Check out a session and run the function in a transaction; once
889+
# done, flip the sanity check bit back.
890+
try:
891+
with SessionCheckout(self._pool) as session:
892+
return session.run_in_transaction(func, *args, **kw)
893+
finally:
894+
self._local.transaction_running = False
916895

917896
def restore(self, source):
918897
"""Restore from a backup to this database.
@@ -1189,7 +1168,7 @@ def __enter__(self):
11891168
current_span = get_current_span()
11901169
session = self._session = self._database._pool.get()
11911170
if current_span:
1192-
current_span.add_event("Using session", {"id": self._session.session_id})
1171+
current_span.add_event("Using session", {"id": session.session_id})
11931172
batch = self._batch = Batch(session)
11941173
if self._request_options.transaction_tag:
11951174
batch.transaction_tag = self._request_options.transaction_tag

google/cloud/spanner_v1/session.py

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -426,17 +426,23 @@ def run_in_transaction(self, func, *args, **kw):
426426
exclude_txn_from_change_streams = kw.pop(
427427
"exclude_txn_from_change_streams", None
428428
)
429-
430-
observability_options = getattr(self._database, "observability_options", None)
431429
attempts = 0
432430

433-
def __run_txn(txn, attempts):
431+
while True:
432+
if self._transaction is None:
433+
txn = self.transaction()
434+
txn.transaction_tag = transaction_tag
435+
txn.exclude_txn_from_change_streams = exclude_txn_from_change_streams
436+
else:
437+
txn = self._transaction
438+
434439
try:
440+
attempts += 1
435441
return_value = func(txn, *args, **kw)
436442
except Aborted as exc:
437443
del self._transaction
438444
_delay_until_retry(exc, deadline, attempts)
439-
return None, False
445+
continue
440446
except GoogleAPICallError:
441447
del self._transaction
442448
raise
@@ -462,28 +468,7 @@ def __run_txn(txn, attempts):
462468
"CommitStats: {}".format(txn.commit_stats),
463469
extra={"commit_stats": txn.commit_stats},
464470
)
465-
return return_value, True
466-
467-
while True:
468-
if self._transaction is None:
469-
with trace_call(
470-
"CloudSpanner.ReadWriteTransaction", self, observability_options
471-
):
472-
txn = self.transaction()
473-
txn.transaction_tag = transaction_tag
474-
txn.exclude_txn_from_change_streams = (
475-
exclude_txn_from_change_streams
476-
)
477-
return_value, completed = __run_txn(txn, attempts)
478-
if completed:
479-
return return_value
480-
else:
481-
txn = self._transaction
482-
return_value, completed = __run_txn(txn, attempts)
483-
if completed:
484-
return return_value
485-
486-
attempts += 1
471+
return return_value
487472

488473

489474
# Rational: this function factors out complex shared deadline / retry

google/cloud/spanner_v1/transaction.py

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,31 @@ def commit(
248248
raise ValueError("Transaction is not begun")
249249

250250
database = self._session._database
251+
api = database.spanner_api
252+
metadata = _metadata_with_prefix(database.name)
253+
if database._route_to_leader_enabled:
254+
metadata.append(
255+
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
256+
)
257+
258+
if request_options is None:
259+
request_options = RequestOptions()
260+
elif type(request_options) is dict:
261+
request_options = RequestOptions(request_options)
262+
if self.transaction_tag is not None:
263+
request_options.transaction_tag = self.transaction_tag
264+
# Request tags are not supported for commit requests.
265+
request_options.request_tag = None
266+
267+
request = CommitRequest(
268+
session=self._session.name,
269+
mutations=self._mutations,
270+
transaction_id=self._transaction_id,
271+
return_commit_stats=return_commit_stats,
272+
max_commit_delay=max_commit_delay,
273+
request_options=request_options,
274+
)
275+
251276
trace_attributes = {"num_mutations": len(self._mutations)}
252277
observability_options = getattr(database, "observability_options", None)
253278
with trace_call(
@@ -262,32 +287,6 @@ def commit(
262287
if span:
263288
span.add_event("Starting Commit")
264289

265-
api = database.spanner_api
266-
metadata = _metadata_with_prefix(database.name)
267-
if database._route_to_leader_enabled:
268-
metadata.append(
269-
_metadata_with_leader_aware_routing(
270-
database._route_to_leader_enabled
271-
)
272-
)
273-
274-
if request_options is None:
275-
request_options = RequestOptions()
276-
elif type(request_options) is dict:
277-
request_options = RequestOptions(request_options)
278-
if self.transaction_tag is not None:
279-
request_options.transaction_tag = self.transaction_tag
280-
# Request tags are not supported for commit requests.
281-
request_options.request_tag = None
282-
283-
request = CommitRequest(
284-
session=self._session.name,
285-
mutations=self._mutations,
286-
transaction_id=self._transaction_id,
287-
return_commit_stats=return_commit_stats,
288-
max_commit_delay=max_commit_delay,
289-
request_options=request_options,
290-
)
291290
method = functools.partial(
292291
api.commit,
293292
request=request,

0 commit comments

Comments
 (0)