diff --git a/google/cloud/datastore/aggregation.py b/google/cloud/datastore/aggregation.py index 47ebfebd..1384f332 100644 --- a/google/cloud/datastore/aggregation.py +++ b/google/cloud/datastore/aggregation.py @@ -442,13 +442,11 @@ def _next_page(self): return None query_pb = self._build_protobuf() - transaction = self.client.current_transaction - if transaction is None: - transaction_id = None - else: - transaction_id = transaction.id + transaction_id, new_transaction_options = helpers.get_transaction_options( + self.client.current_transaction + ) read_options = helpers.get_read_options( - self._eventual, transaction_id, self._read_time + self._eventual, transaction_id, self._read_time, new_transaction_options ) partition_id = entity_pb2.PartitionId( diff --git a/google/cloud/datastore/batch.py b/google/cloud/datastore/batch.py index e0dbf26d..69100bc6 100644 --- a/google/cloud/datastore/batch.py +++ b/google/cloud/datastore/batch.py @@ -192,6 +192,19 @@ def mutations(self): """ return self._mutations + def _allow_mutations(self) -> bool: + """ + This method is called to see if the batch is in a proper state to allow + `put` and `delete` operations. + + the Transaction subclass overrides this method to support + the `begin_later` flag. + + :rtype: bool + :returns: True if the batch is in a state to allow mutations. + """ + return self._status == self._IN_PROGRESS + def put(self, entity): """Remember an entity's state to be saved during :meth:`commit`. @@ -218,7 +231,7 @@ def put(self, entity): progress, if entity has no key assigned, or if the key's ``project`` does not match ours. """ - if self._status != self._IN_PROGRESS: + if not self._allow_mutations(): raise ValueError("Batch must be in progress to put()") if entity.key is None: @@ -248,7 +261,7 @@ def delete(self, key): progress, if key is not complete, or if the key's ``project`` does not match ours. """ - if self._status != self._IN_PROGRESS: + if not self._allow_mutations(): raise ValueError("Batch must be in progress to delete()") if key.is_partial: @@ -370,10 +383,12 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): try: - if exc_type is None: - self.commit() - else: - self.rollback() + # commit or rollback if not in terminal state + if self._status not in (self._ABORTED, self._FINISHED): + if exc_type is None: + self.commit() + else: + self.rollback() finally: self._client._pop_batch() diff --git a/google/cloud/datastore/client.py b/google/cloud/datastore/client.py index 3f5041d6..b1e79d91 100644 --- a/google/cloud/datastore/client.py +++ b/google/cloud/datastore/client.py @@ -122,7 +122,7 @@ def _extended_lookup( missing=None, deferred=None, eventual=False, - transaction_id=None, + transaction=None, retry=None, timeout=None, read_time=None, @@ -158,10 +158,10 @@ def _extended_lookup( consistency. If True, request ``EVENTUAL`` read consistency. - :type transaction_id: str - :param transaction_id: If passed, make the request in the scope of - the given transaction. Incompatible with - ``eventual==True`` or ``read_time``. + :type transaction: Transaction + :param transaction: If passed, make the request in the scope of + the given transaction. Incompatible with + ``eventual==True`` or ``read_time``. :type retry: :class:`google.api_core.retry.Retry` :param retry: @@ -177,7 +177,7 @@ def _extended_lookup( :type read_time: datetime :param read_time: (Optional) Read time to use for read consistency. Incompatible with - ``eventual==True`` or ``transaction_id``. + ``eventual==True`` or ``transaction``. This feature is in private preview. :type database: str @@ -199,8 +199,14 @@ def _extended_lookup( results = [] + transaction_id = None + transaction_id, new_transaction_options = helpers.get_transaction_options( + transaction + ) + read_options = helpers.get_read_options( + eventual, transaction_id, read_time, new_transaction_options + ) loop_num = 0 - read_options = helpers.get_read_options(eventual, transaction_id, read_time) while loop_num < _MAX_LOOPS: # loop against possible deferred. loop_num += 1 request = { @@ -214,6 +220,10 @@ def _extended_lookup( **kwargs, ) + # set new transaction id if we just started a transaction + if transaction and lookup_response.transaction: + transaction._begin_with_id(lookup_response.transaction) + # Accumulate the new results. results.extend(result.entity for result in lookup_response.found) @@ -570,7 +580,7 @@ def get_multi( eventual=eventual, missing=missing, deferred=deferred, - transaction_id=transaction and transaction.id, + transaction=transaction, retry=retry, timeout=timeout, read_time=read_time, diff --git a/google/cloud/datastore/helpers.py b/google/cloud/datastore/helpers.py index e8894883..6eaa3b89 100644 --- a/google/cloud/datastore/helpers.py +++ b/google/cloud/datastore/helpers.py @@ -230,7 +230,9 @@ def entity_to_protobuf(entity): return entity_pb -def get_read_options(eventual, transaction_id, read_time=None): +def get_read_options( + eventual, transaction_id, read_time=None, new_transaction_options=None +): """Validate rules for read options, and assign to the request. Helper method for ``lookup()`` and ``run_query``. @@ -245,33 +247,55 @@ def get_read_options(eventual, transaction_id, read_time=None): :type read_time: datetime :param read_time: Read data from the specified time (may be null). This feature is in private preview. + :type new_transaction_options: :class:`google.cloud.datastore_v1.types.TransactionOptions` + :param new_transaction_options: Options for a new transaction. + :rtype: :class:`.datastore_pb2.ReadOptions` :returns: The read options corresponding to the inputs. :raises: :class:`ValueError` if more than one of ``eventual==True``, - ``transaction``, and ``read_time`` is specified. + ``transaction_id``, ``read_time``, and ``new_transaction_options`` is specified. """ - if transaction_id is None: - if eventual: - if read_time is not None: - raise ValueError("eventual must be False when read_time is specified") - else: - return datastore_pb2.ReadOptions( - read_consistency=datastore_pb2.ReadOptions.ReadConsistency.EVENTUAL - ) - else: - if read_time is None: - return datastore_pb2.ReadOptions() - else: - read_time_pb = timestamp_pb2.Timestamp() - read_time_pb.FromDatetime(read_time) - return datastore_pb2.ReadOptions(read_time=read_time_pb) - else: - if eventual: - raise ValueError("eventual must be False when in a transaction") - elif read_time is not None: - raise ValueError("transaction and read_time are mutual exclusive") - else: - return datastore_pb2.ReadOptions(transaction=transaction_id) + is_set = [ + bool(x) for x in (eventual, transaction_id, read_time, new_transaction_options) + ] + if sum(is_set) > 1: + raise ValueError( + "At most one of eventual, transaction, or read_time is allowed." + ) + new_options = datastore_pb2.ReadOptions() + if transaction_id is not None: + new_options.transaction = transaction_id + if read_time is not None: + read_time_pb = timestamp_pb2.Timestamp() + read_time_pb.FromDatetime(read_time) + new_options.read_time = read_time_pb + if new_transaction_options is not None: + new_options.new_transaction = new_transaction_options + if eventual: + new_options.read_consistency = ( + datastore_pb2.ReadOptions.ReadConsistency.EVENTUAL + ) + return new_options + + +def get_transaction_options(transaction): + """ + Get the transaction_id or new_transaction_options field from an active transaction object, + for use in get_read_options + + These are mutually-exclusive fields, so one or both will be None. + + :rtype: Tuple[Optional[bytes], Optional[google.cloud.datastore_v1.types.TransactionOptions]] + :returns: The transaction_id and new_transaction_options fields from the transaction object. + """ + transaction_id, new_transaction_options = None, None + if transaction is not None: + if transaction.id is not None: + transaction_id = transaction.id + elif transaction._begin_later and transaction._status == transaction._INITIAL: + # If the transaction has not yet been begun, we can use the new_transaction_options field. + new_transaction_options = transaction._options + return transaction_id, new_transaction_options def key_from_protobuf(pb): diff --git a/google/cloud/datastore/query.py b/google/cloud/datastore/query.py index 57c0702c..72d6fe51 100644 --- a/google/cloud/datastore/query.py +++ b/google/cloud/datastore/query.py @@ -778,13 +778,12 @@ def _next_page(self): return None query_pb = self._build_protobuf() - transaction = self.client.current_transaction - if transaction is None: - transaction_id = None - else: - transaction_id = transaction.id + new_transaction_options = None + transaction_id, new_transaction_options = helpers.get_transaction_options( + self.client.current_transaction + ) read_options = helpers.get_read_options( - self._eventual, transaction_id, self._read_time + self._eventual, transaction_id, self._read_time, new_transaction_options ) partition_id = entity_pb2.PartitionId( diff --git a/google/cloud/datastore/transaction.py b/google/cloud/datastore/transaction.py index 3e71ae26..52c17ce2 100644 --- a/google/cloud/datastore/transaction.py +++ b/google/cloud/datastore/transaction.py @@ -13,7 +13,6 @@ # limitations under the License. """Create / interact with Google Cloud Datastore transactions.""" - from google.cloud.datastore.batch import Batch from google.cloud.datastore_v1.types import TransactionOptions from google.protobuf import timestamp_pb2 @@ -149,15 +148,23 @@ class Transaction(Batch): :param read_time: (Optional) Time at which the transaction reads entities. Only allowed when ``read_only=True``. This feature is in private preview. + :type begin_later: bool + :param begin_later: (Optional) If True, the transaction will be started + lazily (i.e. when the first RPC is made). If False, + the transaction will be started as soon as the context manager + is entered. `self.begin()` can also be called manually to begin + the transaction at any time. Default is False. + :raises: :class:`ValueError` if read_time is specified when ``read_only=False``. """ _status = None - def __init__(self, client, read_only=False, read_time=None): + def __init__(self, client, read_only=False, read_time=None, begin_later=False): super(Transaction, self).__init__(client) self._id = None + self._begin_later = begin_later if read_only: if read_time is not None: @@ -180,8 +187,8 @@ def __init__(self, client, read_only=False, read_time=None): def id(self): """Getter for the transaction ID. - :rtype: str - :returns: The ID of the current transaction. + :rtype: bytes or None + :returns: The ID of the current transaction, or None if not started. """ return self._id @@ -240,6 +247,21 @@ def begin(self, retry=None, timeout=None): self._status = self._ABORTED raise + def _begin_with_id(self, transaction_id): + """ + Attach newly created transaction to an existing transaction ID. + + This is used when begin_later is True, when the first lookup request + associated with this transaction creates a new transaction ID. + + :type transaction_id: bytes + :param transaction_id: ID of the transaction to attach to. + """ + if self._status is not self._INITIAL: + raise ValueError("Transaction already begun.") + self._id = transaction_id + self._status = self._IN_PROGRESS + def rollback(self, retry=None, timeout=None): """Rolls back the current transaction. @@ -258,6 +280,12 @@ def rollback(self, retry=None, timeout=None): Note that if ``retry`` is specified, the timeout applies to each individual attempt. """ + # if transaction has not started, abort it + if self._status == self._INITIAL: + self._status = self._ABORTED + self._id = None + return None + kwargs = _make_retry_timeout_kwargs(retry, timeout) try: @@ -296,6 +324,15 @@ def commit(self, retry=None, timeout=None): Note that if ``retry`` is specified, the timeout applies to each individual attempt. """ + # if transaction has not begun, either begin now, or abort if empty + if self._status == self._INITIAL: + if not self._mutations: + self._status = self._ABORTED + self._id = None + return None + else: + self.begin() + kwargs = _make_retry_timeout_kwargs(retry, timeout) try: @@ -321,3 +358,18 @@ def put(self, entity): raise RuntimeError("Transaction is read only") else: super(Transaction, self).put(entity) + + def __enter__(self): + if not self._begin_later: + self.begin() + self._client._push_batch(self) + return self + + def _allow_mutations(self): + """ + Mutations can be added to a transaction if it is in IN_PROGRESS state, + or if it is in INITIAL state and the begin_later flag is set. + """ + return self._status == self._IN_PROGRESS or ( + self._begin_later and self._status == self._INITIAL + ) diff --git a/tests/system/test_transaction.py b/tests/system/test_transaction.py index 6dc9dacd..2f7a6897 100644 --- a/tests/system/test_transaction.py +++ b/tests/system/test_transaction.py @@ -41,6 +41,57 @@ def test_transaction_via_with_statement( assert retrieved_entity == entity +@pytest.mark.parametrize("database_id", [None, _helpers.TEST_DATABASE], indirect=True) +@pytest.mark.parametrize("first_call", ["get", "put", "delete"]) +def test_transaction_begin_later( + datastore_client, entities_to_delete, database_id, first_call +): + """ + transactions with begin_later should call begin on first get rpc, or on commit + """ + key = datastore_client.key("Company", "Google") + entity = datastore.Entity(key=key) + entity["url"] = "www.google.com" + + datastore_client.put(entity) + result_entity = datastore_client.get(key) + + with datastore_client.transaction(begin_later=True) as xact: + assert xact._id is None + assert xact._status == xact._INITIAL + if first_call == "get": + datastore_client.get(entity.key) + assert xact._status == xact._IN_PROGRESS + assert xact._id is not None + elif first_call == "put": + xact.put(entity) + assert xact._status == xact._INITIAL + elif first_call == "delete": + xact.delete(result_entity.key) + assert xact._status == xact._INITIAL + assert xact._status == xact._FINISHED + + entities_to_delete.append(result_entity) + + +@pytest.mark.parametrize("database_id", [None, _helpers.TEST_DATABASE], indirect=True) +@pytest.mark.parametrize("raise_exception", [True, False]) +def test_transaction_begin_later_noop(datastore_client, database_id, raise_exception): + """ + empty begin later transactions should terminate quietly + """ + try: + with datastore_client.transaction(begin_later=True) as xact: + assert xact._id is None + assert xact._status == xact._INITIAL + if raise_exception: + raise RuntimeError("test") + except RuntimeError: + pass + assert xact._status == xact._ABORTED + assert xact._id is None + + @pytest.mark.parametrize("database_id", [None, _helpers.TEST_DATABASE], indirect=True) def test_transaction_via_explicit_begin_get_commit( datastore_client, entities_to_delete, database_id diff --git a/tests/unit/test_aggregation.py b/tests/unit/test_aggregation.py index 15d11aca..8284b808 100644 --- a/tests/unit/test_aggregation.py +++ b/tests/unit/test_aggregation.py @@ -471,7 +471,9 @@ def _next_page_helper(txn_id=None, retry=None, timeout=None, database_id=None): if txn_id is None: client = _Client(project, datastore_api=ds_api, database=database_id) else: - transaction = mock.Mock(id=txn_id, spec=["id"]) + transaction = mock.Mock( + id=txn_id, _begin_later=False, spec=["id", "_begin_later"] + ) client = _Client( project, datastore_api=ds_api, transaction=transaction, database=database_id ) @@ -612,6 +614,57 @@ def test_transaction_id_populated(database_id, aggregation_type, aggregation_arg assert read_options.transaction == client.current_transaction.id +@pytest.mark.parametrize("database_id", [None, "somedb"], indirect=True) +@pytest.mark.parametrize( + "aggregation_type,aggregation_args", + [ + ("count", ()), + ( + "sum", + ("appearances",), + ), + ("avg", ("appearances",)), + ], +) +def test_transaction_begin_later(database_id, aggregation_type, aggregation_args): + """ + When an aggregation is run in the context of a transaction with begin_later=True, + the new_transaction field should be populated in the request read_options. + """ + import mock + from google.cloud.datastore_v1.types import TransactionOptions + + # make a fake begin_later transaction + transaction = mock.Mock() + transaction.id = None + transaction._begin_later = True + transaction._status = transaction._INITIAL + transaction._options = TransactionOptions(read_only=TransactionOptions.ReadOnly()) + mock_datastore_api = mock.Mock() + mock_gapic = mock_datastore_api.run_aggregation_query + mock_gapic.return_value = _make_aggregation_query_response([]) + client = _Client( + None, + datastore_api=mock_datastore_api, + database=database_id, + transaction=transaction, + ) + + query = _make_query(client) + aggregation_query = _make_aggregation_query(client=client, query=query) + + # initiate requested aggregation (ex count, sum, avg) + getattr(aggregation_query, aggregation_type)(*aggregation_args) + # run mock query + list(aggregation_query.fetch()) + assert mock_gapic.call_count == 1 + request = mock_gapic.call_args[1]["request"] + read_options = request["read_options"] + # ensure new_transaction is populated + assert not read_options.transaction + assert read_options.new_transaction == transaction._options + + class _Client(object): def __init__( self, diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 412f3923..2b5c01f4 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -705,6 +705,52 @@ def test_client_get_multi_hit_w_transaction(database_id): ds_api.lookup.assert_called_once_with(request=expected_request) +@pytest.mark.parametrize("database_id", [None, "somedb"]) +def test_client_get_multi_hit_w_transaction_begin_later(database_id): + """ + Transactions with begin_later set should begin on first read + """ + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + from google.cloud.datastore.key import Key + + kind = "Kind" + id_ = 1234 + expected_server_id = b"123" + + # Make a found entity pb to be returned from mock backend. + entity_pb = _make_entity_pb(PROJECT, kind, id_, "foo", "Foo", database=database_id) + + # Make a connection to return the entity pb. + creds = _make_credentials() + client = _make_client(credentials=creds, database=database_id) + lookup_response = _make_lookup_response( + results=[entity_pb], transaction=expected_server_id + ) + ds_api = _make_datastore_api(lookup_response=lookup_response) + client._datastore_api_internal = ds_api + + key = Key(kind, id_, project=PROJECT, database=database_id) + txn = client.transaction(begin_later=True) + assert txn._id is None + assert txn._status == txn._INITIAL + client.get_multi([key], transaction=txn) + + # transaction should now be started + assert txn._id == expected_server_id + assert txn._id is not None + assert txn._status == txn._IN_PROGRESS + + # check rpc args + expected_read_options = datastore_pb2.ReadOptions(new_transaction=txn._options) + expected_request = { + "project_id": PROJECT, + "keys": [key.to_protobuf()], + "read_options": expected_read_options, + } + set_database_id_to_request(expected_request, database_id) + ds_api.lookup.assert_called_once_with(request=expected_request) + + @pytest.mark.parametrize("database_id", [None, "somedb"]) def test_client_get_multi_hit_w_read_time(database_id): from datetime import datetime @@ -1847,7 +1893,7 @@ def _make_commit_response(*keys): return datastore_pb2.CommitResponse(mutation_results=mutation_results) -def _make_lookup_response(results=(), missing=(), deferred=()): +def _make_lookup_response(results=(), missing=(), deferred=(), transaction=None): entity_results_found = [ mock.Mock(entity=result, spec=["entity"]) for result in results ] @@ -1858,7 +1904,8 @@ def _make_lookup_response(results=(), missing=(), deferred=()): found=entity_results_found, missing=entity_results_missing, deferred=deferred, - spec=["found", "missing", "deferred"], + transaction=transaction, + spec=["found", "missing", "deferred", "transaction"], ) diff --git a/tests/unit/test_helpers.py b/tests/unit/test_helpers.py index 89bf6165..38702dba 100644 --- a/tests/unit/test_helpers.py +++ b/tests/unit/test_helpers.py @@ -586,6 +586,91 @@ def test__get_read_options_w_default_wo_txn_w_read_time(): assert read_options == expected +def test__get_read_options_w_new_transaction(): + from google.cloud.datastore.helpers import get_read_options + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + + input_options = datastore_pb2.TransactionOptions() + read_options = get_read_options(False, None, new_transaction_options=input_options) + expected = datastore_pb2.ReadOptions(new_transaction=input_options) + assert read_options == expected + + +@pytest.mark.parametrize( + "args", + [ + (True, "id"), + (True, "id", None), + (True, None, "read_time"), + (True, None, None, "new"), + (False, "id", "read_time"), + (False, "id", None, "new"), + (False, None, "read_time", "new"), + ], +) +def test__get_read_options_w_multiple_args(args): + """ + arguments are mutually exclusive. + Should raise ValueError if multiple are set + """ + from google.cloud.datastore.helpers import get_read_options + + with pytest.raises(ValueError): + get_read_options(*args) + + +def test__get_transaction_options_none(): + """ + test with empty transaction input + """ + from google.cloud.datastore.helpers import get_transaction_options + + t_id, new_t = get_transaction_options(None) + assert t_id is None + assert new_t is None + + +def test__get_transaction_options_w_id(): + """ + test with transaction with id set + """ + from google.cloud.datastore.helpers import get_transaction_options + from google.cloud.datastore import Transaction + + expected_id = b"123abc" + txn = Transaction(None, begin_later=True) + txn._id = expected_id + t_id, new_t = get_transaction_options(txn) + assert t_id == expected_id + assert new_t is None + + +def test__get_transaction_options_w_begin_later(): + """ + if begin later is set and it hasn't begun, should return new_transaction_options + """ + from google.cloud.datastore.helpers import get_transaction_options + from google.cloud.datastore import Transaction + + txn = Transaction(None, begin_later=True) + t_id, new_t = get_transaction_options(txn) + assert t_id is None + assert new_t is txn._options + + +def test__get_transaction_options_not_started(): + """ + If the transaction is noet set as begin_later, but it hasn't begun, return None for both + """ + from google.cloud.datastore.helpers import get_transaction_options + from google.cloud.datastore import Transaction + + txn = Transaction(None, begin_later=False) + t_id, new_t = get_transaction_options(txn) + assert t_id is None + assert new_t is None + + def test__pb_attr_value_w_datetime_naive(): import calendar import datetime diff --git a/tests/unit/test_query.py b/tests/unit/test_query.py index 84c0bedf..6c2063bb 100644 --- a/tests/unit/test_query.py +++ b/tests/unit/test_query.py @@ -667,7 +667,7 @@ def test_eventual_transaction_fails(database_id): @pytest.mark.parametrize("database_id", [None, "somedb"]) def test_transaction_id_populated(database_id): """ - When an aggregation is run in the context of a transaction, the transaction + When an query is run in the context of a transaction, the transaction ID should be populated in the request. """ import mock @@ -698,6 +698,47 @@ def test_transaction_id_populated(database_id): assert read_options.transaction == client.current_transaction.id +@pytest.mark.parametrize("database_id", [None, "somedb"]) +def test_query_transaction_begin_later(database_id): + """ + When an aggregation is run in the context of a transaction with begin_later=True, + the new_transaction field should be populated in the request read_options. + """ + import mock + from google.cloud.datastore_v1.types import TransactionOptions + + # make a fake begin_later transaction + transaction = mock.Mock() + transaction.id = None + transaction._begin_later = True + transaction._status = transaction._INITIAL + transaction._options = TransactionOptions(read_only=TransactionOptions.ReadOnly()) + + mock_datastore_api = mock.Mock() + mock_gapic = mock_datastore_api.run_query + + more_results_enum = 3 # NO_MORE_RESULTS + response_pb = _make_query_response([], b"", more_results_enum, 0) + mock_gapic.return_value = response_pb + + client = _Client( + None, + datastore_api=mock_datastore_api, + database=database_id, + transaction=transaction, + ) + + query = _make_query(client) + # run mock query + list(query.fetch()) + assert mock_gapic.call_count == 1 + request = mock_gapic.call_args[1]["request"] + read_options = request["read_options"] + # ensure new_transaction is populated + assert not read_options.transaction + assert read_options.new_transaction == transaction._options + + def test_iterator_constructor_defaults(): query = object() client = object() @@ -885,7 +926,9 @@ def _next_page_helper( if txn_id is None: client = _Client(project, database=database, datastore_api=ds_api) else: - transaction = mock.Mock(id=txn_id, spec=["id"]) + transaction = mock.Mock( + id=txn_id, _begin_later=False, spec=["id", "_begin_later"] + ) client = _Client( project, database=database, datastore_api=ds_api, transaction=transaction ) diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index 23574ef4..cee384bb 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -81,6 +81,27 @@ def test_transaction_constructor_read_write_w_read_time(database_id): _make_transaction(client, read_only=False, read_time=read_time) +@pytest.mark.parametrize("database_id", [None, "somedb"]) +def test_transaction_constructor_begin_later(database_id): + from google.cloud.datastore.transaction import Transaction + + project = "PROJECT" + client = _Client(project, database=database_id) + expected_id = b"1234" + + xact = _make_transaction(client, begin_later=True) + assert xact._status == Transaction._INITIAL + assert xact.id is None + + xact._begin_with_id(expected_id) + assert xact._status == Transaction._IN_PROGRESS + assert xact.id == expected_id + + # calling a second time should raise exeception + with pytest.raises(ValueError): + xact._begin_with_id(expected_id) + + @pytest.mark.parametrize("database_id", [None, "somedb"]) def test_transaction_current(database_id): from google.cloud.datastore_v1.types import datastore as datastore_pb2 @@ -375,6 +396,7 @@ def test_transaction_context_manager_no_raise(database_id): xact = _make_transaction(client) with xact: + assert xact._status == xact._IN_PROGRESS # only set between begin / commit assert xact.id == id_ @@ -427,6 +449,34 @@ class Foo(Exception): client._datastore_api.rollback.assert_called_once_with(request=expected_request) +@pytest.mark.parametrize("with_exception", [False, True]) +@pytest.mark.parametrize("database_id", [None, "somedb"]) +def test_transaction_context_manager_w_begin_later(database_id, with_exception): + """ + If begin_later is set, don't begin transaction when entering context manager + """ + project = "PROJECT" + id_ = 912830 + ds_api = _make_datastore_api(xact_id=id_) + client = _Client(project, datastore_api=ds_api, database=database_id) + xact = _make_transaction(client, begin_later=True) + + try: + with xact: + assert xact._status == xact._INITIAL + assert xact.id is None + if with_exception: + raise RuntimeError("expected") + except RuntimeError: + pass + # should be finalized after context manager block + assert xact._status == xact._ABORTED + assert xact.id is None + # no need to call commit or rollback + assert ds_api.commit.call_count == 0 + assert ds_api.rollback.call_count == 0 + + @pytest.mark.parametrize("database_id", [None, "somedb"]) def test_transaction_put_read_only(database_id): project = "PROJECT" @@ -441,6 +491,100 @@ def test_transaction_put_read_only(database_id): xact.put(entity) +@pytest.mark.parametrize("database_id", [None, "somedb"]) +def test_transaction_put_w_begin_later(database_id): + """ + If begin_later is set, should be able to call put without begin first + """ + project = "PROJECT" + id_ = 943243 + ds_api = _make_datastore_api(xact_id=id_) + client = _Client(project, datastore_api=ds_api, database=database_id) + entity = _Entity(database=database_id) + with _make_transaction(client, begin_later=True) as xact: + assert xact._status == xact._INITIAL + assert len(xact.mutations) == 0 + xact.put(entity) + assert len(xact.mutations) == 1 + # should still be in initial state + assert xact._status == xact._INITIAL + + +@pytest.mark.parametrize("database_id", [None, "somedb"]) +def test_transaction_delete_w_begin_later(database_id): + """ + If begin_later is set, should be able to call delete without begin first + """ + project = "PROJECT" + id_ = 943243 + ds_api = _make_datastore_api(xact_id=id_) + client = _Client(project, datastore_api=ds_api, database=database_id) + entity = _Entity(database=database_id) + with _make_transaction(client, begin_later=True) as xact: + assert xact._status == xact._INITIAL + assert len(xact.mutations) == 0 + xact.delete(entity.key.completed_key("name")) + assert len(xact.mutations) == 1 + # should still be in initial state + assert xact._status == xact._INITIAL + + +@pytest.mark.parametrize("database_id", [None, "somedb"]) +def test_transaction_rollback_no_begin(database_id): + """ + If rollback is called without begin, transaciton should abort + """ + project = "PROJECT" + id_ = 943243 + ds_api = _make_datastore_api(xact_id=id_) + client = _Client(project, datastore_api=ds_api, database=database_id) + with _make_transaction(client, begin_later=True) as xact: + assert xact._status == xact._INITIAL + with mock.patch.object(xact, "begin") as begin: + xact.rollback() + begin.assert_not_called() + assert xact._status == xact._ABORTED + + +@pytest.mark.parametrize("database_id", [None, "somedb"]) +def test_transaction_commit_no_begin(database_id): + """ + If commit is called without begin, and it has mutations staged, + should call begin before commit + """ + project = "PROJECT" + id_ = 943243 + ds_api = _make_datastore_api(xact_id=id_) + client = _Client(project, datastore_api=ds_api, database=database_id) + entity = _Entity(database=database_id) + with _make_transaction(client, begin_later=True) as xact: + assert xact._status == xact._INITIAL + xact.put(entity) + assert xact._status == xact._INITIAL + with mock.patch.object(xact, "begin") as begin: + begin.side_effect = lambda: setattr(xact, "_status", xact._IN_PROGRESS) + xact.commit() + begin.assert_called_once_with() + + +@pytest.mark.parametrize("database_id", [None, "somedb"]) +def test_empty_transaction_commit(database_id): + """ + If commit is called without begin, and it has no mutations staged, + should abort + """ + project = "PROJECT" + id_ = 943243 + ds_api = _make_datastore_api(xact_id=id_) + client = _Client(project, datastore_api=ds_api, database=database_id) + with _make_transaction(client, begin_later=True) as xact: + assert xact._status == xact._INITIAL + with mock.patch.object(xact, "begin") as begin: + xact.commit() + begin.assert_not_called() + assert xact._status == xact._ABORTED + + def _make_key(kind, id_, project, database=None): from google.cloud.datastore_v1.types import entity as entity_pb2