Skip to content

Commit

Permalink
Use '_restart_on_unavailable' wrapper in 'SRS.{read,execute_sql}.
Browse files Browse the repository at this point in the history
Closes #3775.
  • Loading branch information
tseaver committed Sep 21, 2017
1 parent 2fe76c8 commit 056f3c9
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
13 changes: 11 additions & 2 deletions spanner/google/cloud/spanner/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

"""Model a set of read-only queries to a database as a snapshot."""

import functools

from google.protobuf.struct_pb2 import Struct
from google.cloud.proto.spanner.v1.transaction_pb2 import TransactionOptions
from google.cloud.proto.spanner.v1.transaction_pb2 import TransactionSelector
Expand Down Expand Up @@ -116,11 +118,14 @@ def read(self, table, columns, keyset, index='', limit=0):
options = _options_with_prefix(database.name)
transaction = self._make_txn_selector()

iterator = api.streaming_read(
restart = functools.partial(
api.streaming_read,
self._session.name, table, columns, keyset.to_pb(),
transaction=transaction, index=index, limit=limit,
options=options)

iterator = _restart_on_unavailable(restart)

self._read_request_count += 1

if self._multi_use:
Expand Down Expand Up @@ -173,11 +178,15 @@ def execute_sql(self, sql, params=None, param_types=None, query_mode=None):
options = _options_with_prefix(database.name)
transaction = self._make_txn_selector()
api = database.spanner_api
iterator = api.execute_streaming_sql(

restart = functools.partial(
api.execute_streaming_sql,
self._session.name, sql,
transaction=transaction, params=params_pb, param_types=param_types,
query_mode=query_mode, options=options)

iterator = _restart_on_unavailable(restart)

self._read_request_count += 1

if self._multi_use:
Expand Down
4 changes: 2 additions & 2 deletions spanner/tests/unit/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def test_read_grpc_error(self):
derived = self._makeDerived(session)

with self.assertRaises(GaxError):
derived.read(TABLE_NAME, COLUMNS, KEYSET)
list(derived.read(TABLE_NAME, COLUMNS, KEYSET))

(r_session, table, columns, key_set, transaction, index,
limit, resume_token, options) = api._streaming_read_with
Expand Down Expand Up @@ -294,7 +294,7 @@ def test_execute_sql_grpc_error(self):
derived = self._makeDerived(session)

with self.assertRaises(GaxError):
derived.execute_sql(SQL_QUERY)
list(derived.execute_sql(SQL_QUERY))

(r_session, sql, transaction, params, param_types,
resume_token, query_mode, options) = api._executed_streaming_sql_with
Expand Down

0 comments on commit 056f3c9

Please sign in to comment.