Skip to content

Commit db4e343

Browse files
committed
Add 'BatchTransaction' wrapper class (#438)
Encapsulates session ID / transaction ID, to be marshalled across the wire to another process / host for performing partitioned reads / queries.
1 parent 874b69d commit db4e343

File tree

5 files changed

+1064
-30
lines changed

5 files changed

+1064
-30
lines changed

spanner/google/cloud/spanner_v1/database.py

Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
2828
from google.cloud.spanner_v1.batch import Batch
2929
from google.cloud.spanner_v1.gapic.spanner_client import SpannerClient
30+
from google.cloud.spanner_v1.keyset import KeySet
3031
from google.cloud.spanner_v1.pool import BurstyPool
3132
from google.cloud.spanner_v1.pool import SessionCheckout
3233
from google.cloud.spanner_v1.session import Session
@@ -308,6 +309,14 @@ def batch(self):
308309
"""
309310
return BatchCheckout(self)
310311

312+
def batch_transaction(self):
    """Create a wrapper for running a batch read / query.

    :rtype: :class:`~google.cloud.spanner_v1.database.BatchTransaction`
    :returns: new wrapper
    """
    wrapper = BatchTransaction(self)
    return wrapper
319+
311320
def run_in_transaction(self, func, *args, **kw):
312321
"""Perform a unit of work in a transaction, retrying on abort.
313322
@@ -406,6 +415,263 @@ def __exit__(self, exc_type, exc_val, exc_tb):
406415
self._database._pool.put(self._session)
407416

408417

418+
class BatchTransaction(object):
    """Wrapper for generating and processing read / query batches.

    Encapsulates the session ID / transaction ID so they can be marshalled
    across the wire to another process / host for performing partitioned
    reads / queries.

    :type database: :class:`~google.cloud.spanner.database.Database`
    :param database: database to use

    :type read_timestamp: :class:`datetime.datetime`
    :param read_timestamp: Execute all reads at the given timestamp.

    :type min_read_timestamp: :class:`datetime.datetime`
    :param min_read_timestamp: Execute all reads at a
        timestamp >= ``min_read_timestamp``.

    :type max_staleness: :class:`datetime.timedelta`
    :param max_staleness: Read data at a
        timestamp >= NOW - ``max_staleness`` seconds.

    :type exact_staleness: :class:`datetime.timedelta`
    :param exact_staleness: Execute all reads at a timestamp that is
        ``exact_staleness`` old.
    """
    def __init__(
            self, database,
            read_timestamp=None,
            min_read_timestamp=None,
            max_staleness=None,
            exact_staleness=None):

        self._database = database
        # Session / snapshot are created lazily (see ``_get_session`` /
        # ``_get_snapshot``) so an instance can be built cheaply and then
        # either used locally or reconstituted via ``from_dict``.
        self._session = None
        self._snapshot = None
        self._read_timestamp = read_timestamp
        self._min_read_timestamp = min_read_timestamp
        self._max_staleness = max_staleness
        self._exact_staleness = exact_staleness

    @classmethod
    def from_dict(cls, database, mapping):
        """Reconstruct an instance from a mapping.

        :type database: :class:`~google.cloud.spanner.database.Database`
        :param database: database to use

        :type mapping: mapping
        :param mapping: serialized state of the instance

        :rtype: :class:`BatchTransaction`
        :returns: an instance bound to the serialized session / transaction
        """
        instance = cls(database)
        # Rebind to the already-created session / transaction rather than
        # creating new ones: the IDs were produced by another process / host.
        session = instance._session = database.session()
        session._session_id = mapping['session_id']
        txn = session.transaction()
        txn._transaction_id = mapping['transaction_id']
        return instance

    def to_dict(self):
        """Return state as a dictionary.

        Result can be used to serialize the instance and reconstitute
        it later using :meth:`from_dict`.

        :rtype: dict
        """
        session = self._get_session()
        return {
            'session_id': session._session_id,
            'transaction_id': session._transaction._transaction_id,
        }

    def _get_session(self):
        """Create session as needed.

        .. note::

           Caller is responsible for cleaning up the session after
           all partitions have been processed.
        """
        if self._session is None:
            session = self._session = self._database.session()
            session.create()
            txn = session.transaction()
            txn.begin()
        return self._session

    def _get_snapshot(self):
        """Create snapshot if needed."""
        if self._snapshot is None:
            # ``multi_use`` is required: the same snapshot serves every
            # partitioned read / query issued against this transaction.
            self._snapshot = self._get_session().snapshot(
                read_timestamp=self._read_timestamp,
                min_read_timestamp=self._min_read_timestamp,
                max_staleness=self._max_staleness,
                exact_staleness=self._exact_staleness,
                multi_use=True)
        return self._snapshot

    def generate_read_batches(
            self, table, columns, keyset,
            index='', partition_size_bytes=None, max_partitions=None):
        """Start a partitioned batch read operation.

        Uses the ``PartitionRead`` API request to initiate the partitioned
        read. Returns a list of batch information needed to perform the
        actual reads.

        :type table: str
        :param table: name of the table from which to fetch data

        :type columns: list of str
        :param columns: names of columns to be retrieved

        :type keyset: :class:`~google.cloud.spanner_v1.keyset.KeySet`
        :param keyset: keys / ranges identifying rows to be retrieved

        :type index: str
        :param index: (Optional) name of index to use, rather than the
                      table's primary key

        :type partition_size_bytes: int
        :param partition_size_bytes:
            (Optional) desired size for each partition generated. The service
            uses this as a hint, the actual partition size may differ.

        :type max_partitions: int
        :param max_partitions:
            (Optional) desired maximum number of partitions generated. The
            service uses this as a hint, the actual number of partitions may
            differ.

        :rtype: iterable of dict
        :returns:
            mappings of information used to perform actual partitioned reads
            via :meth:`process_read_batch`.
        """
        partitions = self._get_snapshot().partition_read(
            table=table, columns=columns, keyset=keyset, index=index,
            partition_size_bytes=partition_size_bytes,
            max_partitions=max_partitions)

        # The keyset is serialized so each yielded mapping is wire-safe.
        read_info = {
            'table': table,
            'columns': columns,
            'keyset': keyset._to_dict(),
            'index': index,
        }
        for partition in partitions:
            yield {'partition': partition, 'read': read_info.copy()}

    def process_read_batch(self, batch):
        """Process a single, partitioned read.

        :type batch: mapping
        :param batch:
            one of the mappings returned from an earlier call to
            :meth:`generate_read_batches`.

        :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
        :returns: a result set instance which can be used to consume rows.
        """
        kwargs = batch['read']
        # Reconstitute the keyset from its serialized form.
        keyset_dict = kwargs.pop('keyset')
        kwargs['keyset'] = KeySet._from_dict(keyset_dict)
        return self._get_snapshot().read(
            partition=batch['partition'], **kwargs)

    def generate_query_batches(
            self, sql, params=None, param_types=None,
            partition_size_bytes=None, max_partitions=None):
        """Start a partitioned query operation.

        Uses the ``PartitionQuery`` API request to start a partitioned
        query operation. Returns a list of batch information needed to
        perform the actual queries.

        :type sql: str
        :param sql: SQL query statement

        :type params: dict, {str -> column value}
        :param params: values for parameter replacement. Keys must match
                       the names used in ``sql``.

        :type param_types: dict[str -> Union[dict, .types.Type]]
        :param param_types:
            (Optional) maps explicit types for one or more param values;
            required if parameters are passed.

        :type partition_size_bytes: int
        :param partition_size_bytes:
            (Optional) desired size for each partition generated. The service
            uses this as a hint, the actual partition size may differ.

        :type max_partitions: int
        :param max_partitions:
            (Optional) desired maximum number of partitions generated. The
            service uses this as a hint, the actual number of partitions may
            differ.

        :rtype: iterable of dict
        :returns:
            mappings of information used to perform actual partitioned
            queries via :meth:`process_query_batch`.
        """
        partitions = self._get_snapshot().partition_query(
            sql=sql, params=params, param_types=param_types,
            partition_size_bytes=partition_size_bytes,
            max_partitions=max_partitions)

        query_info = {'sql': sql}
        if params:
            query_info['params'] = params
            query_info['param_types'] = param_types

        for partition in partitions:
            yield {'partition': partition, 'query': query_info}

    def process_query_batch(self, batch):
        """Process a single, partitioned query.

        :type batch: mapping
        :param batch:
            one of the mappings returned from an earlier call to
            :meth:`generate_query_batches`.

        :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
        :returns: a result set instance which can be used to consume rows.
        """
        return self._get_snapshot().execute_sql(
            partition=batch['partition'], **batch['query'])

    def process(self, batch):
        """Process a single, partitioned query or read.

        :type batch: mapping
        :param batch:
            one of the mappings returned from an earlier call to
            :meth:`generate_read_batches` or :meth:`generate_query_batches`.

        :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
        :returns: a result set instance which can be used to consume rows.
        :raises ValueError: if batch does not contain either 'read' or 'query'
        """
        if 'query' in batch:
            return self.process_query_batch(batch)
        if 'read' in batch:
            return self.process_read_batch(batch)
        raise ValueError("Invalid batch")

    def close(self):
        """Clean up underlying session."""
        if self._session is not None:
            self._session.delete()
673+
674+
409675
def _check_ddl_statements(value):
410676
"""Validate DDL Statements used to define database schema.
411677

spanner/google/cloud/spanner_v1/keyset.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,35 @@ def _to_pb(self):
8585

8686
return KeyRangePB(**kwargs)
8787

88+
def _to_dict(self):
    """Return keyrange's state as a dict.

    Only endpoints which are actually set (truthy) appear in the result.

    :rtype: dict
    :returns: state of this instance.
    """
    state = {}

    for endpoint in ('start_open', 'start_closed', 'end_open', 'end_closed'):
        value = getattr(self, endpoint)
        if value:
            state[endpoint] = value

    return state
109+
110+
def __eq__(self, other):
    """Two keyranges are equal when their serialized states match."""
    if isinstance(other, self.__class__):
        return self._to_dict() == other._to_dict()
    return NotImplemented
115+
116+
88117

89118
class KeySet(object):
90119
"""Identify table rows via keys / ranges.
@@ -122,3 +151,41 @@ def _to_pb(self):
122151
kwargs['ranges'] = [krange._to_pb() for krange in self.ranges]
123152

124153
return KeySetPB(**kwargs)
154+
155+
def _to_dict(self):
    """Return keyset's state as a dict.

    The result can be used to serialize the instance and reconstitute
    it later using :meth:`_from_dict`.

    :rtype: dict
    :returns: state of this instance.
    """
    # "All keys" collapses to a single marker; keys / ranges are irrelevant.
    if self.all_:
        return {'all': True}

    serialized_ranges = [krange._to_dict() for krange in self.ranges]
    return {'keys': self.keys, 'ranges': serialized_ranges}
171+
172+
def __eq__(self, other):
    """Two keysets are equal when their serialized states match."""
    if isinstance(other, self.__class__):
        return self._to_dict() == other._to_dict()
    return NotImplemented
177+
178+
@classmethod
def _from_dict(cls, mapping):
    """Create an instance from the corresponding state mapping.

    :type mapping: dict
    :param mapping: the instance state.
    """
    # The 'all' marker short-circuits: keys / ranges are irrelevant.
    if mapping.get('all'):
        return cls(all_=True)

    ranges = [
        KeyRange(**range_state)
        for range_state in mapping.get('ranges', ())
    ]
    return cls(keys=mapping.get('keys', ()), ranges=ranges)

0 commit comments

Comments
 (0)