diff --git a/google/cloud/spanner_dbapi/batch_dml_executor.py b/google/cloud/spanner_dbapi/batch_dml_executor.py
new file mode 100644
index 0000000000..f91cf37b59
--- /dev/null
+++ b/google/cloud/spanner_dbapi/batch_dml_executor.py
@@ -0,0 +1,154 @@
+# Copyright 2023 Google LLC All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from enum import Enum
+from typing import TYPE_CHECKING, List
+from google.cloud.spanner_dbapi.checksum import ResultsChecksum
+from google.cloud.spanner_dbapi.parsed_statement import (
+    ParsedStatement,
+    StatementType,
+    Statement,
+)
+from google.rpc.code_pb2 import ABORTED, OK
+from google.api_core.exceptions import Aborted
+
+from google.cloud.spanner_dbapi.utils import StreamedManyResultSets
+
+if TYPE_CHECKING:
+    from google.cloud.spanner_dbapi.cursor import Cursor
+
+
+class BatchDmlExecutor:
+    """Executor that is used when a DML batch is started. These batches only
+    accept DML statements. All DML statements are buffered locally and sent to
+    Spanner when run_batch_dml() is called.
+
+    :type cursor: :class:`~google.cloud.spanner_dbapi.cursor.Cursor`
+    :param cursor: cursor whose connection is used to execute the batch
+    """
+
+    def __init__(self, cursor: "Cursor"):
+        self._cursor = cursor
+        self._connection = cursor.connection
+        self._statements: List[Statement] = []
+
+    def execute_statement(self, parsed_statement: ParsedStatement):
+        """Executes the statement when dml batch is active by buffering the
+        statement in-memory.
+
+        :type parsed_statement: ParsedStatement
+        :param parsed_statement: parsed statement containing sql query and query
+        params
+
+        :raises ProgrammingError: if the statement type is neither INSERT nor
+        UPDATE
+        """
+        from google.cloud.spanner_dbapi import ProgrammingError
+
+        if (
+            parsed_statement.statement_type != StatementType.UPDATE
+            and parsed_statement.statement_type != StatementType.INSERT
+        ):
+            raise ProgrammingError("Only DML statements are allowed in batch DML mode.")
+        self._statements.append(parsed_statement.statement)
+
+    def run_batch_dml(self):
+        """Executes all the buffered statements on the active dml batch by
+        making a call to Spanner.
+
+        :returns: whatever the module-level ``run_batch_dml`` returns
+        """
+        return run_batch_dml(self._cursor, self._statements)
+
+
+def run_batch_dml(cursor: "Cursor", statements: List[Statement]):
+    """Executes all the dml statements by making a batch call to Spanner.
+
+    :type cursor: Cursor
+    :param cursor: Database Cursor object
+
+    :type statements: List[Statement]
+    :param statements: list of statements to execute in batch
+
+    :rtype: StreamedManyResultSets
+    :returns: the results of the batch; returned on both the autocommit and
+        the client-side-transaction path
+
+    :raises OperationalError: if the batch status is neither OK nor ABORTED
+    """
+    from google.cloud.spanner_dbapi import OperationalError
+
+    connection = cursor.connection
+    many_result_set = StreamedManyResultSets()
+    statements_tuple = [statement.get_tuple() for statement in statements]
+    if not connection._client_transaction_started:
+        res = connection.database.run_in_transaction(_do_batch_update, statements_tuple)
+        many_result_set.add_iter(res)
+        cursor._row_count = sum([max(val, 0) for val in res])
+        # Return the result set here too, so both paths hand callers the same type.
+        return many_result_set
+    else:
+        retried = False
+        while True:
+            try:
+                transaction = connection.transaction_checkout()
+                status, res = transaction.batch_update(statements_tuple)
+                res_checksum = ResultsChecksum()
+                res_checksum.consume_result(res)
+                res_checksum.consume_result(status.code)
+                if not retried:
+                    connection._statements.append((statements, res_checksum))
+                cursor._row_count = sum([max(val, 0) for val in res])
+
+                if status.code == ABORTED:
+                    connection._transaction = None
+                    raise Aborted(status.message)
+                elif status.code != OK:
+                    raise OperationalError(status.message)
+                # Record results only after a successful attempt so that an
+                # aborted attempt does not leave a duplicate in the result set.
+                many_result_set.add_iter(res)
+                return many_result_set
+            except Aborted:
+                connection.retry_transaction()
+                retried = True
+
+
+def _do_batch_update(transaction, statements):
+    """Runs the batch on the given transaction and returns its results.
+
+    Passed to ``run_in_transaction`` by ``run_batch_dml`` when no client-side
+    transaction is started.
+
+    :raises Aborted: if the batch status is ABORTED
+    :raises OperationalError: if the batch status is any other non-OK code
+    """
+    from google.cloud.spanner_dbapi import OperationalError
+
+    status, res = transaction.batch_update(statements)
+    if status.code == ABORTED:
+        raise Aborted(status.message)
+    elif status.code != OK:
+        raise OperationalError(status.message)
+    return res
+
+
+class BatchMode(Enum):
+    """Kind of batch currently active on a connection."""
+
+    DML = 1
+    DDL = 2
+    NONE = 3
diff --git a/google/cloud/spanner_dbapi/client_side_statement_executor.py b/google/cloud/spanner_dbapi/client_side_statement_executor.py
index 2d8eeed4a5..06d0d25948 100644
--- a/google/cloud/spanner_dbapi/client_side_statement_executor.py
+++ b/google/cloud/spanner_dbapi/client_side_statement_executor.py
@@ -14,7 +14,7 @@
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
-    from google.cloud.spanner_dbapi import Connection
+    from google.cloud.spanner_dbapi.cursor import Cursor
 
 from google.cloud.spanner_dbapi import ProgrammingError
 from google.cloud.spanner_dbapi.parsed_statement import (
@@ -38,17 +38,18 @@
 )
 
 
-def execute(connection: "Connection", parsed_statement: ParsedStatement):
+def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
     """Executes the client side statements by calling the relevant method.
 
     It is an internal method that can make backwards-incompatible changes.
- :type connection: Connection - :param connection: Connection object of the dbApi + :type cursor: Cursor + :param cursor: Cursor object of the dbApi :type parsed_statement: ParsedStatement :param parsed_statement: parsed_statement based on the sql query """ + connection = cursor.connection if connection.is_closed: raise ProgrammingError(CONNECTION_CLOSED_ERROR) statement_type = parsed_statement.client_side_statement_type @@ -81,6 +82,13 @@ def execute(connection: "Connection", parsed_statement: ParsedStatement): TypeCode.TIMESTAMP, read_timestamp, ) + if statement_type == ClientSideStatementType.START_BATCH_DML: + connection.start_batch_dml(cursor) + return None + if statement_type == ClientSideStatementType.RUN_BATCH: + return connection.run_batch() + if statement_type == ClientSideStatementType.ABORT_BATCH: + return connection.abort_batch() def _get_streamed_result_set(column_name, type_code, column_value): diff --git a/google/cloud/spanner_dbapi/client_side_statement_parser.py b/google/cloud/spanner_dbapi/client_side_statement_parser.py index 35d0e4e609..39970259b2 100644 --- a/google/cloud/spanner_dbapi/client_side_statement_parser.py +++ b/google/cloud/spanner_dbapi/client_side_statement_parser.py @@ -18,6 +18,7 @@ ParsedStatement, StatementType, ClientSideStatementType, + Statement, ) RE_BEGIN = re.compile(r"^\s*(BEGIN|START)(TRANSACTION)?", re.IGNORECASE) @@ -29,6 +30,9 @@ RE_SHOW_READ_TIMESTAMP = re.compile( r"^\s*(SHOW)\s+(VARIABLE)\s+(READ_TIMESTAMP)", re.IGNORECASE ) +RE_START_BATCH_DML = re.compile(r"^\s*(START)\s+(BATCH)\s+(DML)", re.IGNORECASE) +RE_RUN_BATCH = re.compile(r"^\s*(RUN)\s+(BATCH)", re.IGNORECASE) +RE_ABORT_BATCH = re.compile(r"^\s*(ABORT)\s+(BATCH)", re.IGNORECASE) def parse_stmt(query): @@ -54,8 +58,14 @@ def parse_stmt(query): client_side_statement_type = ClientSideStatementType.SHOW_COMMIT_TIMESTAMP if RE_SHOW_READ_TIMESTAMP.match(query): client_side_statement_type = ClientSideStatementType.SHOW_READ_TIMESTAMP + if 
RE_START_BATCH_DML.match(query): + client_side_statement_type = ClientSideStatementType.START_BATCH_DML + if RE_RUN_BATCH.match(query): + client_side_statement_type = ClientSideStatementType.RUN_BATCH + if RE_ABORT_BATCH.match(query): + client_side_statement_type = ClientSideStatementType.ABORT_BATCH if client_side_statement_type is not None: return ParsedStatement( - StatementType.CLIENT_SIDE, query, client_side_statement_type + StatementType.CLIENT_SIDE, Statement(query), client_side_statement_type ) return None diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index f60913fd14..e635563587 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -13,13 +13,14 @@ # limitations under the License. """DB-API Connection for the Google Cloud Spanner.""" - import time import warnings from google.api_core.exceptions import Aborted from google.api_core.gapic_v1.client_info import ClientInfo from google.cloud import spanner_v1 as spanner +from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode, BatchDmlExecutor +from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, Statement from google.cloud.spanner_v1 import RequestOptions from google.cloud.spanner_v1.session import _get_retry_delay from google.cloud.spanner_v1.snapshot import Snapshot @@ -28,7 +29,11 @@ from google.cloud.spanner_dbapi.checksum import _compare_checksums from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.cursor import Cursor -from google.cloud.spanner_dbapi.exceptions import InterfaceError, OperationalError +from google.cloud.spanner_dbapi.exceptions import ( + InterfaceError, + OperationalError, + ProgrammingError, +) from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT from google.cloud.spanner_dbapi.version import PY_VERSION @@ -111,6 +116,8 @@ def __init__(self, instance, database=None, read_only=False): # whether 
transaction started at Spanner. This means that we had # made atleast one call to Spanner. self._spanner_transaction_started = False + self._batch_mode = BatchMode.NONE + self._batch_dml_executor: BatchDmlExecutor = None @property def autocommit(self): @@ -310,7 +317,10 @@ def _rerun_previous_statements(self): statements, checksum = statement transaction = self.transaction_checkout() - status, res = transaction.batch_update(statements) + statements_tuple = [] + for single_statement in statements: + statements_tuple.append(single_statement.get_tuple()) + status, res = transaction.batch_update(statements_tuple) if status.code == ABORTED: raise Aborted(status.details) @@ -476,14 +486,14 @@ def run_prior_DDL_statements(self): return self.database.update_ddl(ddl_statements).result() - def run_statement(self, statement, retried=False): + def run_statement(self, statement: Statement, retried=False): """Run single SQL statement in begun transaction. This method is never used in autocommit mode. In !autocommit mode however it remembers every executed SQL statement with its parameters. - :type statement: :class:`dict` + :type statement: :class:`Statement` :param statement: SQL statement to execute. 
:type retried: bool
@@ -534,6 +544,78 @@ def validate(self):
                 "Expected: [[1]]" % result
             )
 
+    @check_not_closed
+    def start_batch_dml(self, cursor):
+        """Switch the connection into DML batch mode.
+
+        :type cursor: :class:`~google.cloud.spanner_dbapi.cursor.Cursor`
+        :param cursor: cursor handed to the batch executor.
+
+        :raises ProgrammingError: if a batch is already active or the
+            connection is read-only.
+        """
+        if self._batch_mode is not BatchMode.NONE:
+            raise ProgrammingError(
+                "Cannot start a DML batch when a batch is already active"
+            )
+        if self.read_only:
+            raise ProgrammingError(
+                "Cannot start a DML batch when the connection is in read-only mode"
+            )
+        self._batch_mode = BatchMode.DML
+        self._batch_dml_executor = BatchDmlExecutor(cursor)
+
+    @check_not_closed
+    def execute_batch_dml_statement(self, parsed_statement: ParsedStatement):
+        """Hand a statement to the active DML batch executor.
+
+        :type parsed_statement: ParsedStatement
+        :param parsed_statement: statement to add to the batch.
+
+        :raises ProgrammingError: if the connection is not in DML batch mode.
+        """
+        if self._batch_mode is not BatchMode.DML:
+            raise ProgrammingError(
+                "Cannot execute statement when the BatchMode is not DML"
+            )
+        self._batch_dml_executor.execute_statement(parsed_statement)
+
+    @check_not_closed
+    def run_batch(self):
+        """Run the active batch and leave batch mode.
+
+        Batch mode and the executor are reset even if running the batch fails.
+
+        :returns: the result of the DML batch executor, or None when the
+            active batch mode is not DML.
+
+        :raises ProgrammingError: if no batch is active.
+        """
+        if self._batch_mode is BatchMode.NONE:
+            raise ProgrammingError("Cannot run a batch when the BatchMode is not set")
+        # Bind up front: without this, a non-DML batch mode reaches the return
+        # below with the name unbound and raises UnboundLocalError.
+        many_result_set = None
+        try:
+            if self._batch_mode is BatchMode.DML:
+                many_result_set = self._batch_dml_executor.run_batch_dml()
+        finally:
+            self._batch_mode = BatchMode.NONE
+            self._batch_dml_executor = None
+        return many_result_set
+
+    @check_not_closed
+    def abort_batch(self):
+        """Discard the active batch without running it.
+
+        :raises ProgrammingError: if no batch is active.
+        """
+        if self._batch_mode is BatchMode.NONE:
+            raise ProgrammingError("Cannot abort a batch when the BatchMode is not set")
+        if self._batch_mode is BatchMode.DML:
+            self._batch_dml_executor = None
+        self._batch_mode = BatchMode.NONE
+
     def __enter__(self):
         return self
 
diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py
index 726dd26cb4..ff91e9e666 100644
--- a/google/cloud/spanner_dbapi/cursor.py
+++ b/google/cloud/spanner_dbapi/cursor.py
@@ -26,29 +26,33 @@
 from google.api_core.exceptions import OutOfRange
 
 from google.cloud import spanner_v1 as spanner
-from google.cloud.spanner_dbapi.checksum import ResultsChecksum
+from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode
 from google.cloud.spanner_dbapi.exceptions import
IntegrityError from google.cloud.spanner_dbapi.exceptions import InterfaceError from google.cloud.spanner_dbapi.exceptions import OperationalError from google.cloud.spanner_dbapi.exceptions import ProgrammingError -from google.cloud.spanner_dbapi import _helpers, client_side_statement_executor +from google.cloud.spanner_dbapi import ( + _helpers, + client_side_statement_executor, + batch_dml_executor, +) from google.cloud.spanner_dbapi._helpers import ColumnInfo from google.cloud.spanner_dbapi._helpers import CODE_TO_DISPLAY_SIZE from google.cloud.spanner_dbapi import parse_utils from google.cloud.spanner_dbapi.parse_utils import get_param_types -from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner -from google.cloud.spanner_dbapi.parsed_statement import StatementType +from google.cloud.spanner_dbapi.parsed_statement import ( + StatementType, + Statement, + ParsedStatement, +) from google.cloud.spanner_dbapi.utils import PeekIterator from google.cloud.spanner_dbapi.utils import StreamedManyResultSets -from google.rpc.code_pb2 import ABORTED, OK - _UNSET_COUNT = -1 ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"]) -Statement = namedtuple("Statement", "sql, params, param_types, checksum") def check_not_closed(function): @@ -188,17 +192,6 @@ def _do_execute_update_in_autocommit(self, transaction, sql, params): self._itr = PeekIterator(self._result_set) self._row_count = _UNSET_COUNT - def _do_batch_update(self, transaction, statements, many_result_set): - status, res = transaction.batch_update(statements) - many_result_set.add_iter(res) - - if status.code == ABORTED: - raise Aborted(status.message) - elif status.code != OK: - raise OperationalError(status.message) - - self._row_count = sum([max(val, 0) for val in res]) - def _batch_DDLs(self, sql): """ Check that the given operation contains only DDL @@ -242,14 +235,20 @@ def execute(self, sql, args=None): self._row_count = _UNSET_COUNT try: - parsed_statement = 
parse_utils.classify_statement(sql) - + parsed_statement: ParsedStatement = parse_utils.classify_statement( + sql, args + ) if parsed_statement.statement_type == StatementType.CLIENT_SIDE: self._result_set = client_side_statement_executor.execute( - self.connection, parsed_statement + self, parsed_statement ) if self._result_set is not None: - self._itr = PeekIterator(self._result_set) + if isinstance(self._result_set, StreamedManyResultSets): + self._itr = self._result_set + else: + self._itr = PeekIterator(self._result_set) + elif self.connection._batch_mode == BatchMode.DML: + self.connection.execute_batch_dml_statement(parsed_statement) elif self.connection.read_only or ( not self.connection._client_transaction_started and parsed_statement.statement_type == StatementType.QUERY @@ -260,7 +259,7 @@ def execute(self, sql, args=None): if not self.connection._client_transaction_started: self.connection.run_prior_DDL_statements() else: - self._execute_in_rw_transaction(parsed_statement, sql, args) + self._execute_in_rw_transaction(parsed_statement) except (AlreadyExists, FailedPrecondition, OutOfRange) as e: raise IntegrityError(getattr(e, "details", e)) from e @@ -272,26 +271,15 @@ def execute(self, sql, args=None): if self.connection._client_transaction_started is False: self.connection._spanner_transaction_started = False - def _execute_in_rw_transaction(self, parsed_statement, sql, args): + def _execute_in_rw_transaction(self, parsed_statement: ParsedStatement): # For every other operation, we've got to ensure that # any prior DDL statements were run. 
self.connection.run_prior_DDL_statements() - if parsed_statement.statement_type == StatementType.UPDATE: - sql = parse_utils.ensure_where_clause(sql) - sql, args = sql_pyformat_args_to_spanner(sql, args or None) - if self.connection._client_transaction_started: - statement = Statement( - sql, - args, - get_param_types(args or None), - ResultsChecksum(), - ) - ( self._result_set, self._checksum, - ) = self.connection.run_statement(statement) + ) = self.connection.run_statement(parsed_statement.statement) while True: try: @@ -300,13 +288,13 @@ def _execute_in_rw_transaction(self, parsed_statement, sql, args): except Aborted: self.connection.retry_transaction() except Exception as ex: - self.connection._statements.remove(statement) + self.connection._statements.remove(parsed_statement.statement) raise ex else: self.connection.database.run_in_transaction( self._do_execute_update_in_autocommit, - sql, - args or None, + parsed_statement.statement.sql, + parsed_statement.statement.params or None, ) @check_not_closed @@ -343,56 +331,19 @@ def executemany(self, operation, seq_of_params): # For every operation, we've got to ensure that any prior DDL # statements were run. 
self.connection.run_prior_DDL_statements() - - many_result_set = StreamedManyResultSets() - if parsed_statement.statement_type in ( StatementType.INSERT, StatementType.UPDATE, ): statements = [] - for params in seq_of_params: sql, params = parse_utils.sql_pyformat_args_to_spanner( operation, params ) - statements.append((sql, params, get_param_types(params))) - - if not self.connection._client_transaction_started: - self.connection.database.run_in_transaction( - self._do_batch_update, statements, many_result_set - ) - else: - retried = False - total_row_count = 0 - while True: - try: - transaction = self.connection.transaction_checkout() - - res_checksum = ResultsChecksum() - if not retried: - self.connection._statements.append( - (statements, res_checksum) - ) - - status, res = transaction.batch_update(statements) - many_result_set.add_iter(res) - res_checksum.consume_result(res) - res_checksum.consume_result(status.code) - total_row_count += sum([max(val, 0) for val in res]) - - if status.code == ABORTED: - self.connection._transaction = None - raise Aborted(status.message) - elif status.code != OK: - raise OperationalError(status.message) - self._row_count = total_row_count - break - except Aborted: - self.connection.retry_transaction() - retried = True - + statements.append(Statement(sql, params, get_param_types(params))) + many_result_set = batch_dml_executor.run_batch_dml(self, statements) else: + many_result_set = StreamedManyResultSets() for params in seq_of_params: self.execute(operation, params) many_result_set.add_iter(self._itr) diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py index 97276e54f6..76ac951e0c 100644 --- a/google/cloud/spanner_dbapi/parse_utils.py +++ b/google/cloud/spanner_dbapi/parse_utils.py @@ -24,8 +24,9 @@ from . 
import client_side_statement_parser from deprecated import deprecated +from .checksum import ResultsChecksum from .exceptions import Error -from .parsed_statement import ParsedStatement, StatementType +from .parsed_statement import ParsedStatement, StatementType, Statement from .types import DateStr, TimestampStr from .utils import sanitize_literals_for_upload @@ -205,7 +206,7 @@ def classify_stmt(query): return STMT_UPDATING -def classify_statement(query): +def classify_statement(query, args=None): """Determine SQL query type. It is an internal method that can make backwards-incompatible changes. @@ -221,21 +222,29 @@ def classify_statement(query): # PostgreSQL dollar quoted comments are not # supported and will not be stripped. query = sqlparse.format(query, strip_comments=True).strip() - parsed_statement = client_side_statement_parser.parse_stmt(query) + parsed_statement: ParsedStatement = client_side_statement_parser.parse_stmt(query) if parsed_statement is not None: return parsed_statement + query, args = sql_pyformat_args_to_spanner(query, args or None) + statement = Statement( + query, + args, + get_param_types(args or None), + ResultsChecksum(), + ) if RE_DDL.match(query): - return ParsedStatement(StatementType.DDL, query) + return ParsedStatement(StatementType.DDL, statement) if RE_IS_INSERT.match(query): - return ParsedStatement(StatementType.INSERT, query) + return ParsedStatement(StatementType.INSERT, statement) if RE_NON_UPDATE.match(query) or RE_WITH.match(query): # As of 13-March-2020, Cloud Spanner only supports WITH for DQL # statements and doesn't yet support WITH for DML statements. 
- return ParsedStatement(StatementType.QUERY, query) + return ParsedStatement(StatementType.QUERY, statement) - return ParsedStatement(StatementType.UPDATE, query) + statement.sql = ensure_where_clause(query) + return ParsedStatement(StatementType.UPDATE, statement) def sql_pyformat_args_to_spanner(sql, params): diff --git a/google/cloud/spanner_dbapi/parsed_statement.py b/google/cloud/spanner_dbapi/parsed_statement.py index 30f4c1630f..4f633c7b10 100644 --- a/google/cloud/spanner_dbapi/parsed_statement.py +++ b/google/cloud/spanner_dbapi/parsed_statement.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from dataclasses import dataclass from enum import Enum +from typing import Any + +from google.cloud.spanner_dbapi.checksum import ResultsChecksum class StatementType(Enum): @@ -30,10 +32,24 @@ class ClientSideStatementType(Enum): ROLLBACK = 3 SHOW_COMMIT_TIMESTAMP = 4 SHOW_READ_TIMESTAMP = 5 + START_BATCH_DML = 6 + RUN_BATCH = 7 + ABORT_BATCH = 8 + + +@dataclass +class Statement: + sql: str + params: Any = None + param_types: Any = None + checksum: ResultsChecksum = None + + def get_tuple(self): + return self.sql, self.params, self.param_types @dataclass class ParsedStatement: statement_type: StatementType - query: str + statement: Statement client_side_statement_type: ClientSideStatementType = None diff --git a/tests/system/test_dbapi.py b/tests/system/test_dbapi.py index 6a6cc385f6..fdea0b0d17 100644 --- a/tests/system/test_dbapi.py +++ b/tests/system/test_dbapi.py @@ -425,6 +425,125 @@ def test_read_timestamp_client_side_autocommit(self): read_timestamp_query_result_2 = self._cursor.fetchall() assert read_timestamp_query_result_1 != read_timestamp_query_result_2 + @pytest.mark.parametrize("auto_commit", [False, True]) + def test_batch_dml(self, auto_commit): + """Test batch dml.""" + + if auto_commit: + 
self._conn.autocommit = True
+        self._insert_row(1)
+
+        self._cursor.execute("start batch dml")
+        self._insert_row(2)
+        self._insert_row(3)
+        self._cursor.execute("run batch")
+
+        self._insert_row(4)
+
+        # Test starting another dml batch in same transaction works
+        self._cursor.execute("start batch dml")
+        self._insert_row(5)
+        self._insert_row(6)
+        self._cursor.execute("run batch")
+
+        if not auto_commit:
+            self._conn.commit()
+
+        self._cursor.execute("SELECT * FROM contacts")
+        assert (
+            sorted(self._cursor.fetchall())  # list.sort() returns None
+            == sorted(
+                [
+                    (1, "first-name-1", "last-name-1", "test.email@domen.ru"),
+                    (2, "first-name-2", "last-name-2", "test.email@domen.ru"),
+                    (3, "first-name-3", "last-name-3", "test.email@domen.ru"),
+                    (4, "first-name-4", "last-name-4", "test.email@domen.ru"),
+                    (5, "first-name-5", "last-name-5", "test.email@domen.ru"),
+                    (6, "first-name-6", "last-name-6", "test.email@domen.ru"),
+                ]
+            )
+        )
+
+        # Test starting another dml batch in same connection post commit works
+        self._cursor.execute("start batch dml")
+        self._insert_row(7)
+        self._insert_row(8)
+        self._cursor.execute("run batch")
+
+        self._insert_row(9)
+
+        if not auto_commit:
+            self._conn.commit()
+
+        self._cursor.execute("SELECT * FROM contacts")
+        assert len(self._cursor.fetchall()) == 9
+
+    def test_abort_batch_dml(self):
+        """Test abort batch dml."""
+
+        self._cursor.execute("start batch dml")
+        self._insert_row(1)
+        self._insert_row(2)
+        self._cursor.execute("abort batch")
+
+        self._insert_row(3)
+        self._conn.commit()
+
+        self._cursor.execute("SELECT * FROM contacts")
+        got_rows = self._cursor.fetchall()
+        assert len(got_rows) == 1
+        assert got_rows == [(3, "first-name-3", "last-name-3", "test.email@domen.ru")]
+
+    def test_batch_dml_invalid_statements(self):
+        """Test batch dml having invalid statements."""
+
+        # Test first statement in batch is invalid
+        self._cursor.execute("start batch dml")
+        self._cursor.execute(
+            """
+            INSERT INTO unknown_table (contact_id, first_name, last_name,
email) + VALUES (2, 'first-name', 'last-name', 'test.email@domen.ru') + """ + ) + self._insert_row(1) + self._insert_row(2) + with pytest.raises(OperationalError): + self._cursor.execute("run batch") + + # Test middle statement in batch is invalid + self._cursor.execute("start batch dml") + self._insert_row(1) + self._cursor.execute( + """ + INSERT INTO unknown_table (contact_id, first_name, last_name, email) + VALUES (2, 'first-name', 'last-name', 'test.email@domen.ru') + """ + ) + self._insert_row(2) + with pytest.raises(OperationalError): + self._cursor.execute("run batch") + + # Test last statement in batch is invalid + self._cursor.execute("start batch dml") + self._insert_row(1) + self._insert_row(2) + self._cursor.execute( + """ + INSERT INTO unknown_table (contact_id, first_name, last_name, email) + VALUES (2, 'first-name', 'last-name', 'test.email@domen.ru') + """ + ) + with pytest.raises(OperationalError): + self._cursor.execute("run batch") + + def _insert_row(self, i): + self._cursor.execute( + f""" + INSERT INTO contacts (contact_id, first_name, last_name, email) + VALUES ({i}, 'first-name-{i}', 'last-name-{i}', 'test.email@domen.ru') + """ + ) + def test_begin_success_post_commit(self): """Test beginning a new transaction post commiting an existing transaction is possible on a connection, when connection is in autocommit mode.""" diff --git a/tests/unit/spanner_dbapi/test_batch_dml_executor.py b/tests/unit/spanner_dbapi/test_batch_dml_executor.py new file mode 100644 index 0000000000..3dc387bcb6 --- /dev/null +++ b/tests/unit/spanner_dbapi/test_batch_dml_executor.py @@ -0,0 +1,54 @@ +# Copyright 2023 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest import mock + +from google.cloud.spanner_dbapi import ProgrammingError +from google.cloud.spanner_dbapi.batch_dml_executor import BatchDmlExecutor +from google.cloud.spanner_dbapi.parsed_statement import ( + ParsedStatement, + Statement, + StatementType, +) + + +class TestBatchDmlExecutor(unittest.TestCase): + @mock.patch("google.cloud.spanner_dbapi.cursor.Cursor") + def setUp(self, mock_cursor): + self._under_test = BatchDmlExecutor(mock_cursor) + + def test_execute_statement_non_dml_statement_type(self): + parsed_statement = ParsedStatement(StatementType.QUERY, Statement("sql")) + + with self.assertRaises(ProgrammingError): + self._under_test.execute_statement(parsed_statement) + + def test_execute_statement_insert_statement_type(self): + statement = Statement("sql") + + self._under_test.execute_statement( + ParsedStatement(StatementType.INSERT, statement) + ) + + self.assertEqual(self._under_test._statements, [statement]) + + def test_execute_statement_update_statement_type(self): + statement = Statement("sql") + + self._under_test.execute_statement( + ParsedStatement(StatementType.UPDATE, statement) + ) + + self.assertEqual(self._under_test._statements, [statement]) diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 853b78a936..de028c3206 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -19,9 +19,20 @@ import unittest import warnings import pytest -from google.cloud.spanner_dbapi.exceptions import 
InterfaceError, OperationalError + +from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode +from google.cloud.spanner_dbapi.exceptions import ( + InterfaceError, + OperationalError, + ProgrammingError, +) from google.cloud.spanner_dbapi import Connection from google.cloud.spanner_dbapi.connection import CLIENT_TRANSACTION_NOT_STARTED_WARNING +from google.cloud.spanner_dbapi.parsed_statement import ( + ParsedStatement, + StatementType, + Statement, +) PROJECT = "test-project" INSTANCE = "test-instance" @@ -332,6 +343,94 @@ def test_rollback_in_autocommit_mode(self, mock_warn): CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2 ) + def test_start_batch_dml_batch_mode_active(self): + self._under_test._batch_mode = BatchMode.DML + cursor = self._under_test.cursor() + + with self.assertRaises(ProgrammingError): + self._under_test.start_batch_dml(cursor) + + def test_start_batch_dml_connection_read_only(self): + self._under_test.read_only = True + cursor = self._under_test.cursor() + + with self.assertRaises(ProgrammingError): + self._under_test.start_batch_dml(cursor) + + def test_start_batch_dml(self): + cursor = self._under_test.cursor() + + self._under_test.start_batch_dml(cursor) + + self.assertEqual(self._under_test._batch_mode, BatchMode.DML) + + def test_execute_batch_dml_batch_mode_inactive(self): + self._under_test._batch_mode = BatchMode.NONE + + with self.assertRaises(ProgrammingError): + self._under_test.execute_batch_dml_statement( + ParsedStatement(StatementType.UPDATE, Statement("sql")) + ) + + @mock.patch( + "google.cloud.spanner_dbapi.batch_dml_executor.BatchDmlExecutor", autospec=True + ) + def test_execute_batch_dml(self, mock_batch_dml_executor): + self._under_test._batch_mode = BatchMode.DML + self._under_test._batch_dml_executor = mock_batch_dml_executor + + parsed_statement = ParsedStatement(StatementType.UPDATE, Statement("sql")) + self._under_test.execute_batch_dml_statement(parsed_statement) + + 
mock_batch_dml_executor.execute_statement.assert_called_once_with( + parsed_statement + ) + + @mock.patch( + "google.cloud.spanner_dbapi.batch_dml_executor.BatchDmlExecutor", autospec=True + ) + def test_run_batch_batch_mode_inactive(self, mock_batch_dml_executor): + self._under_test._batch_mode = BatchMode.NONE + self._under_test._batch_dml_executor = mock_batch_dml_executor + + with self.assertRaises(ProgrammingError): + self._under_test.run_batch() + + @mock.patch( + "google.cloud.spanner_dbapi.batch_dml_executor.BatchDmlExecutor", autospec=True + ) + def test_run_batch(self, mock_batch_dml_executor): + self._under_test._batch_mode = BatchMode.DML + self._under_test._batch_dml_executor = mock_batch_dml_executor + + self._under_test.run_batch() + + mock_batch_dml_executor.run_batch_dml.assert_called_once_with() + self.assertEqual(self._under_test._batch_mode, BatchMode.NONE) + self.assertEqual(self._under_test._batch_dml_executor, None) + + @mock.patch( + "google.cloud.spanner_dbapi.batch_dml_executor.BatchDmlExecutor", autospec=True + ) + def test_abort_batch_batch_mode_inactive(self, mock_batch_dml_executor): + self._under_test._batch_mode = BatchMode.NONE + self._under_test._batch_dml_executor = mock_batch_dml_executor + + with self.assertRaises(ProgrammingError): + self._under_test.abort_batch() + + @mock.patch( + "google.cloud.spanner_dbapi.batch_dml_executor.BatchDmlExecutor", autospec=True + ) + def test_abort_dml_batch(self, mock_batch_dml_executor): + self._under_test._batch_mode = BatchMode.DML + self._under_test._batch_dml_executor = mock_batch_dml_executor + + self._under_test.abort_batch() + + self.assertEqual(self._under_test._batch_mode, BatchMode.NONE) + self.assertEqual(self._under_test._batch_dml_executor, None) + @mock.patch("google.cloud.spanner_v1.database.Database", autospec=True) def test_run_prior_DDL_statements(self, mock_database): from google.cloud.spanner_dbapi import Connection, InterfaceError @@ -396,7 +495,7 @@ def test_begin(self): 
def test_run_statement_wo_retried(self): """Check that Connection remembers executed statements.""" from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.cursor import Statement + from google.cloud.spanner_dbapi.parsed_statement import Statement sql = """SELECT 23 FROM table WHERE id = @a1""" params = {"a1": "value"} @@ -415,7 +514,7 @@ def test_run_statement_wo_retried(self): def test_run_statement_w_retried(self): """Check that Connection doesn't remember re-executed statements.""" from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.cursor import Statement + from google.cloud.spanner_dbapi.parsed_statement import Statement sql = """SELECT 23 FROM table WHERE id = @a1""" params = {"a1": "value"} @@ -431,7 +530,7 @@ def test_run_statement_w_retried(self): def test_run_statement_w_heterogenous_insert_statements(self): """Check that Connection executed heterogenous insert statements.""" from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.cursor import Statement + from google.cloud.spanner_dbapi.parsed_statement import Statement from google.rpc.status_pb2 import Status from google.rpc.code_pb2 import OK @@ -452,7 +551,7 @@ def test_run_statement_w_heterogenous_insert_statements(self): def test_run_statement_w_homogeneous_insert_statements(self): """Check that Connection executed homogeneous insert statements.""" from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.cursor import Statement + from google.cloud.spanner_dbapi.parsed_statement import Statement from google.rpc.status_pb2 import Status from google.rpc.code_pb2 import OK @@ -507,7 +606,7 @@ def test_rollback_clears_statements(self, mock_transaction): def test_retry_transaction_w_checksum_match(self): """Check retrying an aborted transaction.""" from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from 
google.cloud.spanner_dbapi.cursor import Statement + from google.cloud.spanner_dbapi.parsed_statement import Statement row = ["field1", "field2"] connection = self._make_connection() @@ -536,7 +635,7 @@ def test_retry_transaction_w_checksum_mismatch(self): """ from google.cloud.spanner_dbapi.exceptions import RetryAborted from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.cursor import Statement + from google.cloud.spanner_dbapi.parsed_statement import Statement row = ["field1", "field2"] retried_row = ["field3", "field4"] @@ -560,7 +659,7 @@ def test_commit_retry_aborted_statements(self, mock_client): from google.api_core.exceptions import Aborted from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.connection import connect - from google.cloud.spanner_dbapi.cursor import Statement + from google.cloud.spanner_dbapi.parsed_statement import Statement row = ["field1", "field2"] @@ -592,7 +691,7 @@ def test_retry_aborted_retry(self, mock_client): from google.api_core.exceptions import Aborted from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.connection import connect - from google.cloud.spanner_dbapi.cursor import Statement + from google.cloud.spanner_dbapi.parsed_statement import Statement row = ["field1", "field2"] @@ -625,7 +724,7 @@ def test_retry_transaction_raise_max_internal_retries(self): """Check retrying raise an error of max internal retries.""" from google.cloud.spanner_dbapi import connection as conn from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.cursor import Statement + from google.cloud.spanner_dbapi.parsed_statement import Statement conn.MAX_INTERNAL_RETRIES = 0 row = ["field1", "field2"] @@ -651,7 +750,7 @@ def test_retry_aborted_retry_without_delay(self, mock_client): from google.api_core.exceptions import Aborted from google.cloud.spanner_dbapi.checksum import 
ResultsChecksum from google.cloud.spanner_dbapi.connection import connect - from google.cloud.spanner_dbapi.cursor import Statement + from google.cloud.spanner_dbapi.parsed_statement import Statement row = ["field1", "field2"] @@ -684,7 +783,7 @@ def test_retry_aborted_retry_without_delay(self, mock_client): def test_retry_transaction_w_multiple_statement(self): """Check retrying an aborted transaction.""" from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.cursor import Statement + from google.cloud.spanner_dbapi.parsed_statement import Statement row = ["field1", "field2"] connection = self._make_connection() @@ -712,7 +811,7 @@ def test_retry_transaction_w_multiple_statement(self): def test_retry_transaction_w_empty_response(self): """Check retrying an aborted transaction.""" from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.cursor import Statement + from google.cloud.spanner_dbapi.parsed_statement import Statement row = [] connection = self._make_connection() @@ -927,7 +1026,7 @@ def test_staleness_single_use_readonly_autocommit(self, MockedPeekIterator): def test_request_priority(self): from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.cursor import Statement + from google.cloud.spanner_dbapi.parsed_statement import Statement from google.cloud.spanner_v1 import RequestOptions sql = "SELECT 1" diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index dfa0a0ac17..3328b0e17f 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -17,7 +17,11 @@ import sys import unittest -from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, StatementType +from google.cloud.spanner_dbapi.parsed_statement import ( + ParsedStatement, + StatementType, + Statement, +) class TestCursor(unittest.TestCase): @@ -213,8 +217,8 @@ def 
test_execute_statement(self): with mock.patch( "google.cloud.spanner_dbapi.parse_utils.classify_statement", side_effect=[ - ParsedStatement(StatementType.DDL, sql), - ParsedStatement(StatementType.UPDATE, sql), + ParsedStatement(StatementType.DDL, Statement(sql)), + ParsedStatement(StatementType.UPDATE, Statement(sql)), ], ) as mockclassify_statement: with self.assertRaises(ValueError): @@ -225,7 +229,7 @@ def test_execute_statement(self): with mock.patch( "google.cloud.spanner_dbapi.parse_utils.classify_statement", - return_value=ParsedStatement(StatementType.DDL, sql), + return_value=ParsedStatement(StatementType.DDL, Statement(sql)), ) as mockclassify_statement: sql = "sql" cursor.execute(sql=sql) @@ -235,11 +239,11 @@ def test_execute_statement(self): with mock.patch( "google.cloud.spanner_dbapi.parse_utils.classify_statement", - return_value=ParsedStatement(StatementType.QUERY, sql), + return_value=ParsedStatement(StatementType.QUERY, Statement(sql)), ): with mock.patch( "google.cloud.spanner_dbapi.cursor.Cursor._handle_DQL", - return_value=ParsedStatement(StatementType.QUERY, sql), + return_value=ParsedStatement(StatementType.QUERY, Statement(sql)), ) as mock_handle_ddl: connection.autocommit = True sql = "sql" @@ -248,13 +252,13 @@ def test_execute_statement(self): with mock.patch( "google.cloud.spanner_dbapi.parse_utils.classify_statement", - return_value=ParsedStatement(StatementType.UPDATE, sql), + return_value=ParsedStatement(StatementType.UPDATE, Statement(sql)), ): cursor.connection._database = mock_db = mock.MagicMock() mock_db.run_in_transaction = mock_run_in = mock.MagicMock() cursor.execute(sql="sql") mock_run_in.assert_called_once_with( - cursor._do_execute_update_in_autocommit, "sql WHERE 1=1", None + cursor._do_execute_update_in_autocommit, "sql", None ) def test_execute_integrity_error(self): @@ -618,12 +622,12 @@ def test_executemany_insert_batch_aborted(self): self.assertEqual( connection._statements[0][0], [ - ( + Statement( """INSERT INTO 
table (col1, "col2", `col3`, `"col4"`) VALUES (@a0, @a1, @a2, @a3)""", {"a0": 1, "a1": 2, "a2": 3, "a3": 4}, {"a0": INT64, "a1": INT64, "a2": INT64, "a3": INT64}, ), - ( + Statement( """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (@a0, @a1, @a2, @a3)""", {"a0": 5, "a1": 6, "a2": 7, "a3": 8}, {"a0": INT64, "a1": INT64, "a2": INT64, "a3": INT64},