Skip to content

Commit

Permalink
feat: Batch Write API implementation and samples
Browse files Browse the repository at this point in the history
  • Loading branch information
sunny1612 committed Oct 20, 2023
1 parent 4d490cf commit a32dd63
Show file tree
Hide file tree
Showing 8 changed files with 528 additions and 2 deletions.
4 changes: 4 additions & 0 deletions google/cloud/spanner_v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from .types.result_set import ResultSetStats
from .types.spanner import BatchCreateSessionsRequest
from .types.spanner import BatchCreateSessionsResponse
from .types.spanner import BatchWriteRequest
from .types.spanner import BatchWriteResponse
from .types.spanner import BeginTransactionRequest
from .types.spanner import CommitRequest
from .types.spanner import CreateSessionRequest
Expand Down Expand Up @@ -99,6 +101,8 @@
# google.cloud.spanner_v1.types
"BatchCreateSessionsRequest",
"BatchCreateSessionsResponse",
"BatchWriteRequest",
"BatchWriteResponse",
"BeginTransactionRequest",
"CommitRequest",
"CommitResponse",
Expand Down
72 changes: 70 additions & 2 deletions google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from google.cloud.spanner_v1 import CommitRequest
from google.cloud.spanner_v1 import Mutation
from google.cloud.spanner_v1 import TransactionOptions
from google.cloud.spanner_v1 import BatchWriteRequest

from google.cloud.spanner_v1._helpers import _SessionWrapper
from google.cloud.spanner_v1._helpers import _make_list_value_pbs
Expand All @@ -42,9 +43,9 @@ class _BatchBase(_SessionWrapper):
transaction_tag = None
_read_only = False

def __init__(self, session):
def __init__(self, session, mutations=None):
super(_BatchBase, self).__init__(session)
self._mutations = []
self._mutations = [] if mutations is None else mutations

def _check_state(self):
"""Helper for :meth:`commit` et al.
Expand Down Expand Up @@ -215,6 +216,73 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.commit()


class MutationGroups(_SessionWrapper):
"""Accumulate mutations for transmission during :meth:`batch_write`.
:type session: :class:`~google.cloud.spanner_v1.session.Session`
:param session: the session used to perform the commit
"""

committed = None

def __init__(self, session):
super(MutationGroups, self).__init__(session)
self._mutation_groups = []

def group(self):
"""Returns a new mutation_group to which mutations can be added."""
mutation_group = BatchWriteRequest.MutationGroup()
self._mutation_groups.append(mutation_group)
return _BatchBase(self._session, mutation_group.mutations)

def batch_write(self, request_options=None):
"""Executes batch_write.
:type request_options:
:class:`google.cloud.spanner_v1.types.RequestOptions`
:param request_options:
(Optional) Common options for this request.
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_v1.types.RequestOptions`.
:rtype: :class:`Iterable[google.cloud.spanner_v1.types.BatchWriteResponse]`
:returns: a sequence of responses for each batch.
"""
if self.committed is not None:
raise ValueError("MutationGroups already committed")

database = self._session._database
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
if database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
trace_attributes = {"num_mutation_groups": len(self._mutation_groups)}
if request_options is None:
request_options = RequestOptions()
elif type(request_options) is dict:
request_options = RequestOptions(request_options)

request = BatchWriteRequest(
session=self._session.name,
mutation_groups=self._mutation_groups,
request_options=request_options,
)
with trace_call("CloudSpanner.BatchWrite", self._session, trace_attributes):
method = functools.partial(
api.batch_write,
request=request,
metadata=metadata,
)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)
self.committed = True
return response


def _make_write_pb(table, columns, values):
"""Helper for :meth:`Batch.insert` et al.
Expand Down
45 changes: 45 additions & 0 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
_metadata_with_leader_aware_routing,
)
from google.cloud.spanner_v1.batch import Batch
from google.cloud.spanner_v1.batch import MutationGroups
from google.cloud.spanner_v1.keyset import KeySet
from google.cloud.spanner_v1.pool import BurstyPool
from google.cloud.spanner_v1.pool import SessionCheckout
Expand Down Expand Up @@ -734,6 +735,17 @@ def batch(self, request_options=None):
"""
return BatchCheckout(self, request_options)

def mutation_groups(self):
"""Return an object which wraps a mutation_group.
The wrapper *must* be used as a context manager, with the mutation group
as the value returned by the wrapper.
:rtype: :class:`~google.cloud.spanner_v1.database.MutationGroupsCheckout`
:returns: new wrapper
"""
return MutationGroupsCheckout(self)

def batch_snapshot(self, read_timestamp=None, exact_staleness=None):
"""Return an object which wraps a batch read / query.
Expand Down Expand Up @@ -1040,6 +1052,39 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self._database._pool.put(self._session)


class MutationGroupsCheckout(object):
"""Context manager for using mutation groups from a database.
Inside the context manager, checks out a session from the database,
creates mutation groups from it, making the groups available.
Caller must *not* use the object to perform API requests outside the scope
of the context manager.
:type database: :class:`~google.cloud.spanner_v1.database.Database`
:param database: database to use
"""

def __init__(self, database):
self._database = database
self._session = self._mutation_groups = None

def __enter__(self):
"""Begin ``with`` block."""
session = self._session = self._database._pool.get()
return MutationGroups(session)

def __exit__(self, exc_type, exc_val, exc_tb):
"""End ``with`` block."""
if isinstance(exc_val, NotFound):
# If NotFound exception occurs inside the with block
# then we validate if the session still exists.
if not self._session.exists():
self._session = self._database._pool._new_session()
self._session.create()
self._database._pool.put(self._session)


class SnapshotCheckout(object):
"""Context manager for using a snapshot from a database.
Expand Down
51 changes: 51 additions & 0 deletions samples/samples/snippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,54 @@ def insert_data(instance_id, database_id):
# [END spanner_insert_data]


# [START spanner_batch_write]
def batch_write(instance_id, database_id):
"""Inserts sample data into the given database via BatchWrite API.
The database and table must already exist and can be created using
`create_database`.
"""
spanner_client = spanner.Client()
instance = spanner_client.instance(instance_id)
database = instance.database(database_id)

with database.mutation_groups() as groups:
group1 = groups.group()
group1.insert_or_update(
table="Singers",
columns=("SingerId", "FirstName", "LastName"),
values=[
(16, "Scarlet", "Terry"),
],
)

group2 = groups.group()
group2.insert_or_update(
table="Singers",
columns=("SingerId", "FirstName", "LastName"),
values=[
(17, "Marc", "Richards"),
(18, "Catalina", "Smith"),
],
)
group2.insert_or_update(
table="Albums",
columns=("SingerId", "AlbumId", "AlbumTitle"),
values=[
(17, 1, "Total Junk"),
(18, 2, "Go, Go, Go"),
],
)

for response in groups.batch_write():
print(response)

print("Inserted data.")


# [END spanner_batch_write]


# [START spanner_delete_data]
def delete_data(instance_id, database_id):
"""Deletes sample data from the given database.
Expand Down Expand Up @@ -2677,6 +2725,7 @@ def drop_sequence(instance_id, database_id):
subparsers.add_parser("create_instance", help=create_instance.__doc__)
subparsers.add_parser("create_database", help=create_database.__doc__)
subparsers.add_parser("insert_data", help=insert_data.__doc__)
subparsers.add_parser("batch_write", help=batch_write.__doc__)
subparsers.add_parser("delete_data", help=delete_data.__doc__)
subparsers.add_parser("query_data", help=query_data.__doc__)
subparsers.add_parser("read_data", help=read_data.__doc__)
Expand Down Expand Up @@ -2811,6 +2860,8 @@ def drop_sequence(instance_id, database_id):
create_database(args.instance_id, args.database_id)
elif args.command == "insert_data":
insert_data(args.instance_id, args.database_id)
elif args.command == "batch_write":
batch_write(args.instance_id, args.database_id)
elif args.command == "delete_data":
delete_data(args.instance_id, args.database_id)
elif args.command == "query_data":
Expand Down
7 changes: 7 additions & 0 deletions samples/samples/snippets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,13 @@ def test_insert_data(capsys, instance_id, sample_database):
assert "Inserted data" in out


@pytest.mark.dependency(name="batch_write")
def test_batch_write(capsys, instance_id, sample_database):
snippets.batch_write(instance_id, sample_database.database_id)
out, _ = capsys.readouterr()
assert "Inserted data" in out


@pytest.mark.dependency(depends=["insert_data"])
def test_delete_data(capsys, instance_id, sample_database):
snippets.delete_data(instance_id, sample_database.database_id)
Expand Down
38 changes: 38 additions & 0 deletions tests/system/test_session_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2508,6 +2508,44 @@ def test_partition_query(sessions_database, not_emulator):
batch_txn.close()


def test_mutation_groups_insert_or_update_then_query(sessions_database):
sd = _sample_data
ROW_DATA = (
(1, "Phred", "Phlyntstone", "phred@example.com"),
(2, "Bharney", "Rhubble", "bharney@example.com"),
(3, "Wylma", "Phlyntstone", "wylma@example.com"),
(4, "Pebbles", "Phlyntstone", "pebbles@example.com"),
(5, "Betty", "Rhubble", "betty@example.com"),
(6, "Slate", "Stephenson", "slate@example.com"),
)
num_groups = 3
num_mutations_per_group = len(ROW_DATA) // num_groups

with sessions_database.mutation_groups() as groups:
for i in range(num_groups):
group = groups.group()
for j in range(num_mutations_per_group):
group.insert_or_update(
sd.TABLE, sd.COLUMNS, [ROW_DATA[i * num_mutations_per_group + j]]
)
# Response indexes received
seen = collections.Counter()
for response in groups.batch_write():
_check_batch_status(response.status.code)
assert response.commit_timestamp is not None
assert len(response.indexes) > 0
seen.update(response.indexes)
# All indexes must be in the range [0, num_groups-1] and seen exactly once
assert len(seen) == num_groups
assert all((0 <= idx < num_groups and ct == 1) for (idx, ct) in seen.items())

# Verify the writes by reading from the database
with sessions_database.snapshot() as snapshot:
rows = list(snapshot.execute_sql(sd.SQL))

sd._check_rows_data(rows, ROW_DATA)


class FauxCall:
def __init__(self, code, details="FauxCall"):
self._code = code
Expand Down
Loading

0 comments on commit a32dd63

Please sign in to comment.