Retry streaming exceptions (this time for sure, Rocky!) (#4016)
* Add '_restart_on_unavailable' iterator wrapper.

  Tracks the 'resume_token' and issues a restart after a 503.

* Strip knowledge of 'resume_token' from 'StreamedResultSet'.

* Remove 'resume_token' args from 'Snapshot' and 'Session' API surface:

  Retry handling will be done behind the scenes.

* Use '_restart_on_unavailable' wrapper in 'Snapshot.{read,execute_sql}'.

Closes #3775.
tseaver authored Sep 21, 2017
1 parent 1b9edf9 commit 5a0fe35
Showing 6 changed files with 173 additions and 82 deletions.
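
In caller-facing terms, the change removes `resume_token` from the public `read` / `execute_sql` signatures: a stream dropped with a 503 is now resumed internally from the last token the server sent. A minimal before/after sketch, not from this commit; the `session` object is assumed to be an already-created `google.cloud.spanner` Session, and the `citizens` table and column names are illustrative:

    from google.cloud.spanner import KeySet

    keyset = KeySet(all_=True)  # read every row

    # Before this commit, surviving an interruption meant re-issuing the
    # request by hand with the token from the broken result set:
    #
    #     results = session.read('citizens', ('email', 'age'), keyset,
    #                            resume_token=last_token)
    #
    # After it, the keyword is gone and retry-on-503 happens behind the scenes:
    results = session.read('citizens', ('email', 'age'), keyset)
    results.consume_all()  # restarts transparently on ServiceUnavailable
    for row in results.rows:
        print(row)
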
17 changes: 4 additions & 13 deletions spanner/google/cloud/spanner/session.py
@@ -165,8 +165,7 @@ def snapshot(self, **kw):
 
         return Snapshot(self, **kw)
 
-    def read(self, table, columns, keyset, index='', limit=0,
-             resume_token=b''):
+    def read(self, table, columns, keyset, index='', limit=0):
         """Perform a ``StreamingRead`` API request for rows in a table.
 
         :type table: str
@@ -185,17 +184,12 @@ def read(self, table, columns, keyset, index='', limit=0,
         :type limit: int
         :param limit: (Optional) maximum number of rows to return
 
-        :type resume_token: bytes
-        :param resume_token: token for resuming previously-interrupted read
-
         :rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet`
         :returns: a result set instance which can be used to consume rows.
         """
-        return self.snapshot().read(
-            table, columns, keyset, index, limit, resume_token)
+        return self.snapshot().read(table, columns, keyset, index, limit)
 
-    def execute_sql(self, sql, params=None, param_types=None, query_mode=None,
-                    resume_token=b''):
+    def execute_sql(self, sql, params=None, param_types=None, query_mode=None):
         """Perform an ``ExecuteStreamingSql`` API request.
 
         :type sql: str
@@ -216,14 +210,11 @@ def execute_sql(self, sql, params=None, param_types=None, query_mode=None,
         :param query_mode: Mode governing return of results / query plan. See
             https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.ExecuteSqlRequest.QueryMode1
 
-        :type resume_token: bytes
-        :param resume_token: token for resuming previously-interrupted query
-
         :rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet`
         :returns: a result set instance which can be used to consume rows.
         """
         return self.snapshot().execute_sql(
-            sql, params, param_types, query_mode, resume_token)
+            sql, params, param_types, query_mode)
 
     def batch(self):
         """Factory to create a batch for this session.
60 changes: 46 additions & 14 deletions spanner/google/cloud/spanner/snapshot.py
@@ -14,10 +14,13 @@
 
 """Model a set of read-only queries to a database as a snapshot."""
 
+import functools
+
 from google.protobuf.struct_pb2 import Struct
 from google.cloud.proto.spanner.v1.transaction_pb2 import TransactionOptions
 from google.cloud.proto.spanner.v1.transaction_pb2 import TransactionSelector
 
+from google.api.core.exceptions import ServiceUnavailable
 from google.cloud._helpers import _datetime_to_pb_timestamp
 from google.cloud._helpers import _timedelta_to_duration_pb
 from google.cloud.spanner._helpers import _make_value_pb
@@ -26,6 +29,36 @@
 from google.cloud.spanner.streamed import StreamedResultSet
 
 
+def _restart_on_unavailable(restart):
+    """Restart iteration after :exc:`.ServiceUnavailable`.
+
+    :type restart: callable
+    :param restart: curried function returning iterator
+    """
+    resume_token = ''
+    item_buffer = []
+    iterator = restart()
+    while True:
+        try:
+            for item in iterator:
+                item_buffer.append(item)
+                if item.resume_token:
+                    resume_token = item.resume_token
+                    break
+        except ServiceUnavailable:
+            del item_buffer[:]
+            iterator = restart(resume_token=resume_token)
+            continue
+
+        if len(item_buffer) == 0:
+            break
+
+        for item in item_buffer:
+            yield item
+
+        del item_buffer[:]
+
+
 class _SnapshotBase(_SessionWrapper):
     """Base class for Snapshot.
@@ -49,8 +82,7 @@ def _make_txn_selector(self):  # pylint: disable=redundant-returns-doc
         """
         raise NotImplementedError
 
-    def read(self, table, columns, keyset, index='', limit=0,
-             resume_token=b''):
+    def read(self, table, columns, keyset, index='', limit=0):
         """Perform a ``StreamingRead`` API request for rows in a table.
 
         :type table: str
@@ -69,9 +101,6 @@ def read(self, table, columns, keyset, index='', limit=0,
         :type limit: int
         :param limit: (Optional) maximum number of rows to return
 
-        :type resume_token: bytes
-        :param resume_token: token for resuming previously-interrupted read
-
         :rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet`
         :returns: a result set instance which can be used to consume rows.
         :raises ValueError:
@@ -89,10 +118,13 @@ def read(self, table, columns, keyset, index='', limit=0,
         options = _options_with_prefix(database.name)
         transaction = self._make_txn_selector()
 
-        iterator = api.streaming_read(
+        restart = functools.partial(
+            api.streaming_read,
             self._session.name, table, columns, keyset.to_pb(),
             transaction=transaction, index=index, limit=limit,
-            resume_token=resume_token, options=options)
+            options=options)
+
+        iterator = _restart_on_unavailable(restart)
 
         self._read_request_count += 1
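
Both `read` and `execute_sql` now build the request the same way: everything except `resume_token` is frozen into a `functools.partial`, so the wrapper can re-issue an identical request with only the token added. A standalone sketch of that currying pattern follows; the function and argument names are illustrative, not Spanner's API:

    import functools

    def streaming_read(session, table, columns, resume_token=b'', options=None):
        # Stand-in for the gRPC method; it just records how it was called.
        print('read %s/%s from token %r' % (session, table, resume_token))
        return iter(())

    # Freeze everything except 'resume_token' up front.
    restart = functools.partial(
        streaming_read, 'sessions/s1', 'citizens', ('email', 'age'),
        options=None)

    restart()                    # first attempt: resume_token defaults to b''
    restart(resume_token=b'T1')  # after a 503: same request, token added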

@@ -101,8 +133,7 @@ def read(self, table, columns, keyset, index='', limit=0,
         else:
             return StreamedResultSet(iterator)
 
-    def execute_sql(self, sql, params=None, param_types=None, query_mode=None,
-                    resume_token=b''):
+    def execute_sql(self, sql, params=None, param_types=None, query_mode=None):
         """Perform an ``ExecuteStreamingSql`` API request for rows in a table.
 
         :type sql: str
@@ -122,9 +153,6 @@ def execute_sql(self, sql, params=None, param_types=None, query_mode=None,
         :param query_mode: Mode governing return of results / query plan. See
             https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.ExecuteSqlRequest.QueryMode1
 
-        :type resume_token: bytes
-        :param resume_token: token for resuming previously-interrupted query
-
         :rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet`
         :returns: a result set instance which can be used to consume rows.
         :raises ValueError:
@@ -150,10 +178,14 @@ def execute_sql(self, sql, params=None, param_types=None, query_mode=None,
         options = _options_with_prefix(database.name)
         transaction = self._make_txn_selector()
         api = database.spanner_api
-        iterator = api.execute_streaming_sql(
+
+        restart = functools.partial(
+            api.execute_streaming_sql,
             self._session.name, sql,
             transaction=transaction, params=params_pb, param_types=param_types,
-            query_mode=query_mode, resume_token=resume_token, options=options)
+            query_mode=query_mode, options=options)
+
+        iterator = _restart_on_unavailable(restart)
 
         self._read_request_count += 1
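
A sketch of the new helper's behavior against a fake stream (not from the commit's test suite). It assumes the 2017-era import path used in the diff above, and `_restart_on_unavailable` is a private helper, imported here only for illustration:

    from google.api.core.exceptions import ServiceUnavailable
    from google.cloud.spanner.snapshot import _restart_on_unavailable

    class _Item(object):
        """Stand-in for a PartialResultSet: one value plus an optional token."""
        def __init__(self, value, resume_token=b''):
            self.value = value
            self.resume_token = resume_token

    STREAM = [
        _Item('a'), _Item('b', resume_token=b'T1'),  # first checkpoint
        _Item('c'), _Item('d', resume_token=b'T2'),  # second checkpoint
    ]

    calls = []

    def fake_restart(resume_token=b''):
        """Replay the stream from 'resume_token', raising one 503 on the way."""
        calls.append(resume_token)
        start = 2 if resume_token == b'T1' else 0
        for item in STREAM[start:]:
            if len(calls) == 1 and item.resume_token == b'T2':
                raise ServiceUnavailable('stream dropped')  # simulated 503
            yield item

    values = [item.value for item in _restart_on_unavailable(fake_restart)]
    assert values == ['a', 'b', 'c', 'd']  # no gaps, no duplicates
    assert calls == [b'', b'T1']           # one restart, from the last checkpoint

Items seen after the last token ('c' here) are buffered, discarded on failure, and replayed after the restart, which is why the wrapper holds them back instead of yielding eagerly.
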
11 changes: 0 additions & 11 deletions spanner/google/cloud/spanner/streamed.py
@@ -43,7 +43,6 @@ def __init__(self, response_iterator, source=None):
         self._counter = 0           # Counter for processed responses
         self._metadata = None       # Until set from first PRS
         self._stats = None          # Until set from last PRS
-        self._resume_token = None   # To resume from last received PRS
         self._current_row = []      # Accumulated values for incomplete row
         self._pending_chunk = None  # Incomplete value
         self._source = source       # Source snapshot
@@ -85,15 +84,6 @@ def stats(self):
         """
         return self._stats
 
-    @property
-    def resume_token(self):
-        """Token for resuming interrupted read / query.
-
-        :rtype: bytes
-        :returns: token from last chunk of results.
-        """
-        return self._resume_token
-
     def _merge_chunk(self, value):
         """Merge pending chunk with next value.
@@ -132,7 +122,6 @@ def consume_next(self):
         """
         response = six.next(self._response_iterator)
         self._counter += 1
-        self._resume_token = response.resume_token
 
         if self._metadata is None:  # first response
             metadata = self._metadata = response.metadata
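
For downstream code, the streamed.py deletions above are a breaking change: `StreamedResultSet.resume_token` is gone, so manual checkpointing code must simply be removed. A tiny illustration, continuing the hedged sketch from the top of this page:

    results = session.read('citizens', ('email', 'age'), keyset)
    # token = results.resume_token   # AttributeError after this commit
    assert not hasattr(results, 'resume_token')
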
21 changes: 8 additions & 13 deletions spanner/tests/unit/test_session.py
@@ -265,7 +265,6 @@ def test_read(self):
         KEYSET = KeySet(keys=KEYS)
         INDEX = 'email-address-index'
         LIMIT = 20
-        TOKEN = b'DEADBEEF'
         database = _Database(self.DATABASE_NAME)
         session = self._make_one(database)
         session._session_id = 'DEADBEEF'
@@ -279,28 +278,26 @@ def __init__(self, session, **kwargs):
                 self._session = session
                 self._kwargs = kwargs.copy()
 
-            def read(self, table, columns, keyset, index='', limit=0,
-                     resume_token=b''):
+            def read(self, table, columns, keyset, index='', limit=0):
                 _read_with.append(
-                    (table, columns, keyset, index, limit, resume_token))
+                    (table, columns, keyset, index, limit))
                 return expected
 
         with _Monkey(MUT, Snapshot=_Snapshot):
             found = session.read(
                 TABLE_NAME, COLUMNS, KEYSET,
-                index=INDEX, limit=LIMIT, resume_token=TOKEN)
+                index=INDEX, limit=LIMIT)
 
         self.assertIs(found, expected)
 
         self.assertEqual(len(_read_with), 1)
-        (table, columns, key_set, index, limit, resume_token) = _read_with[0]
+        (table, columns, key_set, index, limit) = _read_with[0]
 
         self.assertEqual(table, TABLE_NAME)
         self.assertEqual(columns, COLUMNS)
         self.assertEqual(key_set, KEYSET)
         self.assertEqual(index, INDEX)
         self.assertEqual(limit, LIMIT)
-        self.assertEqual(resume_token, TOKEN)
 
     def test_execute_sql_not_created(self):
         SQL = 'SELECT first_name, age FROM citizens'
@@ -330,25 +327,23 @@ def __init__(self, session, **kwargs):
                 self._kwargs = kwargs.copy()
 
             def execute_sql(
-                    self, sql, params=None, param_types=None, query_mode=None,
-                    resume_token=None):
+                    self, sql, params=None, param_types=None, query_mode=None):
                 _executed_sql_with.append(
-                    (sql, params, param_types, query_mode, resume_token))
+                    (sql, params, param_types, query_mode))
                 return expected
 
         with _Monkey(MUT, Snapshot=_Snapshot):
-            found = session.execute_sql(SQL, resume_token=TOKEN)
+            found = session.execute_sql(SQL)
 
         self.assertIs(found, expected)
 
         self.assertEqual(len(_executed_sql_with), 1)
-        sql, params, param_types, query_mode, token = _executed_sql_with[0]
+        sql, params, param_types, query_mode = _executed_sql_with[0]
 
         self.assertEqual(sql, SQL)
         self.assertEqual(params, None)
         self.assertEqual(param_types, None)
         self.assertEqual(query_mode, None)
-        self.assertEqual(token, TOKEN)
 
     def test_batch_not_created(self):
         database = _Database(self.DATABASE_NAME)