From a25b2e44ae423555ae01c62a50594158fb09e83e Mon Sep 17 00:00:00 2001 From: Luke Sneeringer Date: Tue, 8 Aug 2017 10:05:34 -0700 Subject: [PATCH] session.run_in_transaction returns the callback's return value. (#3753) --- spanner/google/cloud/spanner/session.py | 10 +++++----- spanner/tests/unit/test_database.py | 4 ++-- spanner/tests/unit/test_session.py | 24 +++++++++--------------- 3 files changed, 16 insertions(+), 22 deletions(-) diff --git a/spanner/google/cloud/spanner/session.py b/spanner/google/cloud/spanner/session.py index 953ab62993ccd..04fcacea38ee8 100644 --- a/spanner/google/cloud/spanner/session.py +++ b/spanner/google/cloud/spanner/session.py @@ -268,8 +268,9 @@ def run_in_transaction(self, func, *args, **kw): If passed, "timeout_secs" will be removed and used to override the default timeout. - :rtype: :class:`datetime.datetime` - :returns: timestamp of committed transaction + :rtype: Any + :returns: The return value of ``func``. + :raises Exception: reraises any non-ABORT execptions raised by ``func``. """ @@ -284,7 +285,7 @@ def run_in_transaction(self, func, *args, **kw): if txn._transaction_id is None: txn.begin() try: - func(txn, *args, **kw) + return_value = func(txn, *args, **kw) except GaxError as exc: _delay_until_retry(exc, deadline) del self._transaction @@ -299,8 +300,7 @@ def run_in_transaction(self, func, *args, **kw): _delay_until_retry(exc, deadline) del self._transaction else: - committed = txn.committed - return committed + return return_value # pylint: disable=misplaced-bare-raise diff --git a/spanner/tests/unit/test_database.py b/spanner/tests/unit/test_database.py index ec94e0198c777..40e10ec971a99 100644 --- a/spanner/tests/unit/test_database.py +++ b/spanner/tests/unit/test_database.py @@ -22,7 +22,7 @@ from google.cloud.spanner import __version__ -def _make_credentials(): +def _make_credentials(): # pragma: NO COVER import google.auth.credentials class _CredentialsWithScopes( @@ -223,7 +223,7 @@ def __init__(self, scopes=(), source=None): self._scopes = scopes self._source = source - def requires_scopes(self): + def requires_scopes(self): # pragma: NO COVER return True def with_scopes(self, scopes): diff --git a/spanner/tests/unit/test_session.py b/spanner/tests/unit/test_session.py index 100555c8e49f8..826369079d29e 100644 --- a/spanner/tests/unit/test_session.py +++ b/spanner/tests/unit/test_session.py @@ -513,16 +513,16 @@ def test_run_in_transaction_w_args_w_kwargs_wo_abort(self): def unit_of_work(txn, *args, **kw): called_with.append((txn, args, kw)) txn.insert(TABLE_NAME, COLUMNS, VALUES) + return 42 - committed = session.run_in_transaction( + return_value = session.run_in_transaction( unit_of_work, 'abc', some_arg='def') - self.assertEqual(committed, now) self.assertIsNone(session._transaction) self.assertEqual(len(called_with), 1) txn, args, kw = called_with[0] self.assertIsInstance(txn, Transaction) - self.assertEqual(txn.committed, committed) + self.assertEqual(return_value, 42) self.assertEqual(args, ('abc',)) self.assertEqual(kw, {'some_arg': 'def'}) @@ -561,18 +561,15 @@ def test_run_in_transaction_w_abort_no_retry_metadata(self): def unit_of_work(txn, *args, **kw): called_with.append((txn, args, kw)) txn.insert(TABLE_NAME, COLUMNS, VALUES) + return 'answer' - committed = session.run_in_transaction( + return_value = session.run_in_transaction( unit_of_work, 'abc', some_arg='def') - self.assertEqual(committed, now) self.assertEqual(len(called_with), 2) for index, (txn, args, kw) in enumerate(called_with): self.assertIsInstance(txn, Transaction) - if index == 1: - self.assertEqual(txn.committed, committed) - else: - self.assertIsNone(txn.committed) + self.assertEqual(return_value, 'answer') self.assertEqual(args, ('abc',)) self.assertEqual(kw, {'some_arg': 'def'}) @@ -621,17 +618,15 @@ def unit_of_work(txn, *args, **kw): time_module = _FauxTimeModule() with _Monkey(MUT, time=time_module): - committed = session.run_in_transaction( - unit_of_work, 'abc', some_arg='def') + session.run_in_transaction(unit_of_work, 'abc', some_arg='def') self.assertEqual(time_module._slept, RETRY_SECONDS + RETRY_NANOS / 1.0e9) - self.assertEqual(committed, now) self.assertEqual(len(called_with), 2) for index, (txn, args, kw) in enumerate(called_with): self.assertIsInstance(txn, Transaction) if index == 1: - self.assertEqual(txn.committed, committed) + self.assertEqual(txn.committed, now) else: self.assertIsNone(txn.committed) self.assertEqual(args, ('abc',)) @@ -688,9 +683,8 @@ def unit_of_work(txn, *args, **kw): time_module = _FauxTimeModule() with _Monkey(MUT, time=time_module): - committed = session.run_in_transaction(unit_of_work) + session.run_in_transaction(unit_of_work) - self.assertEqual(committed, now) self.assertEqual(time_module._slept, RETRY_SECONDS + RETRY_NANOS / 1.0e9) self.assertEqual(len(called_with), 2)