diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index a397028287..6438605d3b 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -22,6 +22,9 @@ from google.cloud import spanner_v1 as spanner from google.cloud.spanner_v1.session import _get_retry_delay +from google.cloud.spanner_dbapi._helpers import _execute_insert_heterogenous +from google.cloud.spanner_dbapi._helpers import _execute_insert_homogenous +from google.cloud.spanner_dbapi._helpers import parse_insert from google.cloud.spanner_dbapi.checksum import _compare_checksums from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.cursor import Cursor @@ -82,7 +85,7 @@ def autocommit(self, value): :type value: bool :param value: New autocommit mode state. """ - if value and not self._autocommit: + if value and not self._autocommit and self.inside_transaction: self.commit() self._autocommit = value @@ -96,6 +99,19 @@ def database(self): """ return self._database + @property + def inside_transaction(self): + """Flag: transaction is started. + + Returns: + bool: True if transaction begun, False otherwise. + """ + return ( + self._transaction + and not self._transaction.committed + and not self._transaction.rolled_back + ) + @property def instance(self): """Instance to which this connection relates. @@ -191,11 +207,7 @@ def transaction_checkout(self): :returns: A Cloud Spanner transaction object, ready to use. """ if not self.autocommit: - if ( - not self._transaction - or self._transaction.committed - or self._transaction.rolled_back - ): + if not self.inside_transaction: self._transaction = self._session_checkout().transaction() self._transaction.begin() @@ -216,11 +228,7 @@ def close(self): The connection will be unusable from this point forward. If the connection has an active transaction, it will be rolled back. """ - if ( - self._transaction - and not self._transaction.committed - and not self._transaction.rolled_back - ): + if self.inside_transaction: self._transaction.rollback() if self._own_pool: @@ -235,7 +243,7 @@ def commit(self): """ if self._autocommit: warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2) - elif self._transaction: + elif self.inside_transaction: try: self._transaction.commit() self._release_session() @@ -291,6 +299,24 @@ def run_statement(self, statement, retried=False): if not retried: self._statements.append(statement) + if statement.is_insert: + parts = parse_insert(statement.sql, statement.params) + + if parts.get("homogenous"): + _execute_insert_homogenous(transaction, parts) + return ( + iter(()), + ResultsChecksum() if retried else statement.checksum, + ) + else: + _execute_insert_heterogenous( + transaction, parts.get("sql_params_list"), + ) + return ( + iter(()), + ResultsChecksum() if retried else statement.checksum, + ) + return ( transaction.execute_sql( statement.sql, statement.params, param_types=statement.param_types, diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 363c2c653c..254eb5734a 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -42,7 +42,7 @@ _UNSET_COUNT = -1 ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"]) -Statement = namedtuple("Statement", "sql, params, param_types, checksum") +Statement = namedtuple("Statement", "sql, params, param_types, checksum, is_insert") class Cursor(object): @@ -95,9 +95,9 @@ def description(self): for field in row_type.fields: column_info = ColumnInfo( name=field.name, - type_code=field.type.code, + type_code=field.type_.code, # Size of the SQL type of the column. - display_size=code_to_display_size.get(field.type.code), + display_size=code_to_display_size.get(field.type_.code), # Client perceived size of the column. internal_size=field.ByteSize(), ) @@ -172,10 +172,20 @@ def execute(self, sql, args=None): self.connection.run_prior_DDL_statements() if not self.connection.autocommit: - sql, params = sql_pyformat_args_to_spanner(sql, args) + if classification == parse_utils.STMT_UPDATING: + sql = parse_utils.ensure_where_clause(sql) + + if classification != parse_utils.STMT_INSERT: + sql, args = sql_pyformat_args_to_spanner(sql, args or None) statement = Statement( - sql, params, get_param_types(params), ResultsChecksum(), + sql, + args, + get_param_types(args or None) + if classification != parse_utils.STMT_INSERT + else {}, + ResultsChecksum(), + classification == parse_utils.STMT_INSERT, ) (self._result_set, self._checksum,) = self.connection.run_statement( statement @@ -233,7 +243,8 @@ def fetchone(self): try: res = next(self) - self._checksum.consume_result(res) + if not self.connection.autocommit: + self._checksum.consume_result(res) return res except StopIteration: return @@ -250,7 +261,8 @@ def fetchall(self): res = [] try: for row in self: - self._checksum.consume_result(row) + if not self.connection.autocommit: + self._checksum.consume_result(row) res.append(row) except Aborted: self._connection.retry_transaction() @@ -278,7 +290,8 @@ def fetchmany(self, size=None): for i in range(size): try: res = next(self) - self._checksum.consume_result(res) + if not self.connection.autocommit: + self._checksum.consume_result(res) items.append(res) except StopIteration: break diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py index 8848233d45..d3dd98dda6 100644 --- a/google/cloud/spanner_dbapi/parse_utils.py +++ b/google/cloud/spanner_dbapi/parse_utils.py @@ -523,19 +523,15 @@ def get_param_types(params): def ensure_where_clause(sql): """ Cloud Spanner requires a WHERE clause on UPDATE and DELETE statements. - Raise an error, if the given sql doesn't include it. + Add a dummy WHERE clause if non detected. :type sql: `str` :param sql: SQL code to check. - - :raises: :class:`ProgrammingError` if the given sql doesn't include a WHERE clause. """ if any(isinstance(token, sqlparse.sql.Where) for token in sqlparse.parse(sql)[0]): return sql - raise ProgrammingError( - "Cloud Spanner requires a WHERE clause when executing DELETE or UPDATE query" - ) + return sql + " WHERE 1=1" def escape_name(name): diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 213eb24d84..a338055a2c 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -15,7 +15,6 @@ """Cloud Spanner DB-API Connection class unit tests.""" import mock -import sys import unittest import warnings @@ -51,25 +50,57 @@ def _make_connection(self): database = instance.database(self.DATABASE) return Connection(instance, database) - @unittest.skipIf(sys.version_info[0] < 3, "Python 2 patching is outdated") - def test_property_autocommit_setter(self): - from google.cloud.spanner_dbapi import Connection - - connection = Connection(self.INSTANCE, self.DATABASE) + def test_autocommit_setter_transaction_not_started(self): + connection = self._make_connection() with mock.patch( "google.cloud.spanner_dbapi.connection.Connection.commit" ) as mock_commit: connection.autocommit = True - mock_commit.assert_called_once_with() - self.assertEqual(connection._autocommit, True) + mock_commit.assert_not_called() + self.assertTrue(connection._autocommit) with mock.patch( "google.cloud.spanner_dbapi.connection.Connection.commit" ) as mock_commit: connection.autocommit = False mock_commit.assert_not_called() - self.assertEqual(connection._autocommit, False) + self.assertFalse(connection._autocommit) + + def test_autocommit_setter_transaction_started(self): + connection = self._make_connection() + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.commit" + ) as mock_commit: + connection._transaction = mock.Mock(committed=False, rolled_back=False) + + connection.autocommit = True + mock_commit.assert_called_once() + self.assertTrue(connection._autocommit) + + def test_autocommit_setter_transaction_started_commited_rolled_back(self): + connection = self._make_connection() + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.commit" + ) as mock_commit: + connection._transaction = mock.Mock(committed=True, rolled_back=False) + + connection.autocommit = True + mock_commit.assert_not_called() + self.assertTrue(connection._autocommit) + + connection.autocommit = False + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.commit" + ) as mock_commit: + connection._transaction = mock.Mock(committed=False, rolled_back=True) + + connection.autocommit = True + mock_commit.assert_not_called() + self.assertTrue(connection._autocommit) def test_property_database(self): from google.cloud.spanner_v1.database import Database @@ -166,7 +197,9 @@ def test_commit(self, mock_warn): connection.commit() mock_release.assert_not_called() - connection._transaction = mock_transaction = mock.MagicMock() + connection._transaction = mock_transaction = mock.MagicMock( + rolled_back=False, committed=False + ) mock_transaction.commit = mock_commit = mock.MagicMock() with mock.patch( @@ -316,7 +349,7 @@ def test_run_statement_remember_statements(self): connection = self._make_connection() - statement = Statement(sql, params, param_types, ResultsChecksum(),) + statement = Statement(sql, params, param_types, ResultsChecksum(), False) with mock.patch( "google.cloud.spanner_dbapi.connection.Connection.transaction_checkout" ): @@ -338,7 +371,7 @@ def test_run_statement_dont_remember_retried_statements(self): connection = self._make_connection() - statement = Statement(sql, params, param_types, ResultsChecksum(),) + statement = Statement(sql, params, param_types, ResultsChecksum(), False) with mock.patch( "google.cloud.spanner_dbapi.connection.Connection.transaction_checkout" ): @@ -352,7 +385,7 @@ def test_clear_statements_on_commit(self): cleared, when the transaction is commited. """ connection = self._make_connection() - connection._transaction = mock.Mock() + connection._transaction = mock.Mock(rolled_back=False, committed=False) connection._statements = [{}, {}] self.assertEqual(len(connection._statements), 2) @@ -390,7 +423,7 @@ def test_retry_transaction(self): checksum.consume_result(row) retried_checkum = ResultsChecksum() - statement = Statement("SELECT 1", [], {}, checksum,) + statement = Statement("SELECT 1", [], {}, checksum, False) connection._statements.append(statement) with mock.patch( @@ -423,7 +456,7 @@ def test_retry_transaction_checksum_mismatch(self): checksum.consume_result(row) retried_checkum = ResultsChecksum() - statement = Statement("SELECT 1", [], {}, checksum,) + statement = Statement("SELECT 1", [], {}, checksum, False) connection._statements.append(statement) with mock.patch( @@ -453,9 +486,9 @@ def test_commit_retry_aborted_statements(self): cursor._checksum = ResultsChecksum() cursor._checksum.consume_result(row) - statement = Statement("SELECT 1", [], {}, cursor._checksum,) + statement = Statement("SELECT 1", [], {}, cursor._checksum, False) connection._statements.append(statement) - connection._transaction = mock.Mock() + connection._transaction = mock.Mock(rolled_back=False, committed=False) with mock.patch.object( connection._transaction, "commit", side_effect=(Aborted("Aborted"), None), @@ -507,7 +540,7 @@ def test_retry_aborted_retry(self): cursor._checksum = ResultsChecksum() cursor._checksum.consume_result(row) - statement = Statement("SELECT 1", [], {}, cursor._checksum,) + statement = Statement("SELECT 1", [], {}, cursor._checksum, False) connection._statements.append(statement) metadata_mock = mock.Mock() diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 81b290c4f1..9f0510c4ab 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -126,7 +126,7 @@ def test_execute_attribute_error(self): cursor = self._make_one(connection) with self.assertRaises(AttributeError): - cursor.execute(sql="") + cursor.execute(sql="SELECT 1") def test_execute_autocommit_off(self): from google.cloud.spanner_dbapi.utils import PeekIterator @@ -531,7 +531,7 @@ def test_fetchone_retry_aborted_statements(self): cursor._checksum = ResultsChecksum() cursor._checksum.consume_result(row) - statement = Statement("SELECT 1", [], {}, cursor._checksum,) + statement = Statement("SELECT 1", [], {}, cursor._checksum, False) connection._statements.append(statement) with mock.patch( @@ -570,7 +570,7 @@ def test_fetchone_retry_aborted_statements_checksums_mismatch(self): cursor._checksum = ResultsChecksum() cursor._checksum.consume_result(row) - statement = Statement("SELECT 1", [], {}, cursor._checksum,) + statement = Statement("SELECT 1", [], {}, cursor._checksum, False) connection._statements.append(statement) with mock.patch( diff --git a/tests/unit/spanner_dbapi/test_parse_utils.py b/tests/unit/spanner_dbapi/test_parse_utils.py index 6d89a8a46a..3713ac11a8 100644 --- a/tests/unit/spanner_dbapi/test_parse_utils.py +++ b/tests/unit/spanner_dbapi/test_parse_utils.py @@ -391,7 +391,6 @@ def test_get_param_types_none(self): @unittest.skipIf(skip_condition, skip_message) def test_ensure_where_clause(self): - from google.cloud.spanner_dbapi.exceptions import ProgrammingError from google.cloud.spanner_dbapi.parse_utils import ensure_where_clause cases = ( @@ -409,8 +408,7 @@ def test_ensure_where_clause(self): for sql in err_cases: with self.subTest(sql=sql): - with self.assertRaises(ProgrammingError): - ensure_where_clause(sql) + self.assertEqual(ensure_where_clause(sql), sql + " WHERE 1=1") @unittest.skipIf(skip_condition, skip_message) def test_escape_name(self):