Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

session.run_in_transaction returns the callback's return value. #3753

Merged
merged 2 commits into from
Aug 8, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions spanner/google/cloud/spanner/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,9 @@ def run_in_transaction(self, func, *args, **kw):
If passed, "timeout_secs" will be removed and used to
override the default timeout.

:rtype: :class:`datetime.datetime`
:returns: timestamp of committed transaction
:rtype: Any
:returns: The return value of ``func``.

:raises Exception:
reraises any non-ABORT execptions raised by ``func``.
"""
Expand All @@ -284,7 +285,7 @@ def run_in_transaction(self, func, *args, **kw):
if txn._transaction_id is None:
txn.begin()
try:
func(txn, *args, **kw)
return_value = func(txn, *args, **kw)
except GaxError as exc:
_delay_until_retry(exc, deadline)
del self._transaction
Expand All @@ -299,8 +300,7 @@ def run_in_transaction(self, func, *args, **kw):
_delay_until_retry(exc, deadline)
del self._transaction
else:
committed = txn.committed
return committed
return return_value


# pylint: disable=misplaced-bare-raise
Expand Down
4 changes: 2 additions & 2 deletions spanner/tests/unit/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from google.cloud.spanner import __version__


def _make_credentials():
def _make_credentials(): # pragma: NO COVER
import google.auth.credentials

class _CredentialsWithScopes(
Expand Down Expand Up @@ -223,7 +223,7 @@ def __init__(self, scopes=(), source=None):
self._scopes = scopes
self._source = source

def requires_scopes(self):
def requires_scopes(self): # pragma: NO COVER
return True

def with_scopes(self, scopes):
Expand Down
24 changes: 9 additions & 15 deletions spanner/tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,16 +513,16 @@ def test_run_in_transaction_w_args_w_kwargs_wo_abort(self):
def unit_of_work(txn, *args, **kw):
called_with.append((txn, args, kw))
txn.insert(TABLE_NAME, COLUMNS, VALUES)
return 42

committed = session.run_in_transaction(
return_value = session.run_in_transaction(
unit_of_work, 'abc', some_arg='def')

self.assertEqual(committed, now)
self.assertIsNone(session._transaction)
self.assertEqual(len(called_with), 1)
txn, args, kw = called_with[0]
self.assertIsInstance(txn, Transaction)
self.assertEqual(txn.committed, committed)
self.assertEqual(return_value, 42)
self.assertEqual(args, ('abc',))
self.assertEqual(kw, {'some_arg': 'def'})

Expand Down Expand Up @@ -561,18 +561,15 @@ def test_run_in_transaction_w_abort_no_retry_metadata(self):
def unit_of_work(txn, *args, **kw):
called_with.append((txn, args, kw))
txn.insert(TABLE_NAME, COLUMNS, VALUES)
return 'answer'

committed = session.run_in_transaction(
return_value = session.run_in_transaction(
unit_of_work, 'abc', some_arg='def')

self.assertEqual(committed, now)
self.assertEqual(len(called_with), 2)
for index, (txn, args, kw) in enumerate(called_with):
self.assertIsInstance(txn, Transaction)
if index == 1:
self.assertEqual(txn.committed, committed)
else:
self.assertIsNone(txn.committed)
self.assertEqual(return_value, 'answer')
self.assertEqual(args, ('abc',))
self.assertEqual(kw, {'some_arg': 'def'})

Expand Down Expand Up @@ -621,17 +618,15 @@ def unit_of_work(txn, *args, **kw):
time_module = _FauxTimeModule()

with _Monkey(MUT, time=time_module):
committed = session.run_in_transaction(
unit_of_work, 'abc', some_arg='def')
session.run_in_transaction(unit_of_work, 'abc', some_arg='def')

self.assertEqual(time_module._slept,
RETRY_SECONDS + RETRY_NANOS / 1.0e9)
self.assertEqual(committed, now)
self.assertEqual(len(called_with), 2)
for index, (txn, args, kw) in enumerate(called_with):
self.assertIsInstance(txn, Transaction)
if index == 1:
self.assertEqual(txn.committed, committed)
self.assertEqual(txn.committed, now)
else:
self.assertIsNone(txn.committed)
self.assertEqual(args, ('abc',))
Expand Down Expand Up @@ -688,9 +683,8 @@ def unit_of_work(txn, *args, **kw):
time_module = _FauxTimeModule()

with _Monkey(MUT, time=time_module):
committed = session.run_in_transaction(unit_of_work)
session.run_in_transaction(unit_of_work)

self.assertEqual(committed, now)
self.assertEqual(time_module._slept,
RETRY_SECONDS + RETRY_NANOS / 1.0e9)
self.assertEqual(len(called_with), 2)
Expand Down