From 25654322d0683da82f3889cfd3c3aa4c769094c5 Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Fri, 30 Oct 2020 18:31:07 -0400 Subject: [PATCH 01/12] feat: DB-API driver + unit tests --- google/cloud/spanner_dbapi/__init__.py | 85 +++ google/cloud/spanner_dbapi/_helpers.py | 159 ++++++ google/cloud/spanner_dbapi/connection.py | 261 +++++++++ google/cloud/spanner_dbapi/cursor.py | 329 +++++++++++ google/cloud/spanner_dbapi/exceptions.py | 94 ++++ google/cloud/spanner_dbapi/parse_utils.py | 542 +++++++++++++++++++ google/cloud/spanner_dbapi/parser.py | 246 +++++++++ google/cloud/spanner_dbapi/types.py | 98 ++++ google/cloud/spanner_dbapi/utils.py | 81 +++ google/cloud/spanner_dbapi/version.py | 11 + noxfile.py | 2 +- tests/unit/spanner_dbapi/__init__.py | 5 + tests/unit/spanner_dbapi/test__helpers.py | 130 +++++ tests/unit/spanner_dbapi/test_connection.py | 318 +++++++++++ tests/unit/spanner_dbapi/test_cursor.py | 460 ++++++++++++++++ tests/unit/spanner_dbapi/test_globals.py | 20 + tests/unit/spanner_dbapi/test_parse_utils.py | 454 ++++++++++++++++ tests/unit/spanner_dbapi/test_parser.py | 288 ++++++++++ tests/unit/spanner_dbapi/test_types.py | 63 +++ tests/unit/spanner_dbapi/test_utils.py | 72 +++ 20 files changed, 3717 insertions(+), 1 deletion(-) create mode 100644 google/cloud/spanner_dbapi/__init__.py create mode 100644 google/cloud/spanner_dbapi/_helpers.py create mode 100644 google/cloud/spanner_dbapi/connection.py create mode 100644 google/cloud/spanner_dbapi/cursor.py create mode 100644 google/cloud/spanner_dbapi/exceptions.py create mode 100644 google/cloud/spanner_dbapi/parse_utils.py create mode 100644 google/cloud/spanner_dbapi/parser.py create mode 100644 google/cloud/spanner_dbapi/types.py create mode 100644 google/cloud/spanner_dbapi/utils.py create mode 100644 google/cloud/spanner_dbapi/version.py create mode 100644 tests/unit/spanner_dbapi/__init__.py create mode 100644 tests/unit/spanner_dbapi/test__helpers.py create mode 100644 tests/unit/spanner_dbapi/test_connection.py create mode 100644 tests/unit/spanner_dbapi/test_cursor.py create mode 100644 tests/unit/spanner_dbapi/test_globals.py create mode 100644 tests/unit/spanner_dbapi/test_parse_utils.py create mode 100644 tests/unit/spanner_dbapi/test_parser.py create mode 100644 tests/unit/spanner_dbapi/test_types.py create mode 100644 tests/unit/spanner_dbapi/test_utils.py diff --git a/google/cloud/spanner_dbapi/__init__.py b/google/cloud/spanner_dbapi/__init__.py new file mode 100644 index 0000000000..7695c0058f --- /dev/null +++ b/google/cloud/spanner_dbapi/__init__.py @@ -0,0 +1,85 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +"""Connection-based DB API for Cloud Spanner.""" + +from google.cloud.spanner_dbapi.connection import Connection +from google.cloud.spanner_dbapi.connection import connect + +from google.cloud.spanner_dbapi.cursor import Cursor + +from google.cloud.spanner_dbapi.exceptions import DatabaseError +from google.cloud.spanner_dbapi.exceptions import DataError +from google.cloud.spanner_dbapi.exceptions import Error +from google.cloud.spanner_dbapi.exceptions import IntegrityError +from google.cloud.spanner_dbapi.exceptions import InterfaceError +from google.cloud.spanner_dbapi.exceptions import InternalError +from google.cloud.spanner_dbapi.exceptions import NotSupportedError +from google.cloud.spanner_dbapi.exceptions import OperationalError +from google.cloud.spanner_dbapi.exceptions import ProgrammingError +from google.cloud.spanner_dbapi.exceptions import Warning + +from google.cloud.spanner_dbapi.parse_utils import get_param_types + +from google.cloud.spanner_dbapi.types import BINARY +from google.cloud.spanner_dbapi.types import DATETIME +from google.cloud.spanner_dbapi.types import NUMBER +from google.cloud.spanner_dbapi.types import ROWID +from google.cloud.spanner_dbapi.types import STRING +from google.cloud.spanner_dbapi.types import Binary +from google.cloud.spanner_dbapi.types import Date +from google.cloud.spanner_dbapi.types import DateFromTicks +from google.cloud.spanner_dbapi.types import Time +from google.cloud.spanner_dbapi.types import TimeFromTicks +from google.cloud.spanner_dbapi.types import Timestamp +from google.cloud.spanner_dbapi.types import TimestampStr +from google.cloud.spanner_dbapi.types import TimestampFromTicks + +from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT + +apilevel = "2.0" # supports DP-API 2.0 level. +paramstyle = "format" # ANSI C printf format codes, e.g. ...WHERE name=%s. + +# Threads may share the module, but not connections. This is a paranoid threadsafety +# level, but it is necessary for starters to use when debugging failures. +# Eventually once transactions are working properly, we'll update the +# threadsafety level. +threadsafety = 1 + + +__all__ = [ + "Connection", + "connect", + "Cursor", + "DatabaseError", + "DataError", + "Error", + "IntegrityError", + "InterfaceError", + "InternalError", + "NotSupportedError", + "OperationalError", + "ProgrammingError", + "Warning", + "DEFAULT_USER_AGENT", + "apilevel", + "paramstyle", + "threadsafety", + "get_param_types", + "Binary", + "Date", + "DateFromTicks", + "Time", + "TimeFromTicks", + "Timestamp", + "TimestampFromTicks", + "BINARY", + "STRING", + "NUMBER", + "DATETIME", + "ROWID", + "TimestampStr", +] diff --git a/google/cloud/spanner_dbapi/_helpers.py b/google/cloud/spanner_dbapi/_helpers.py new file mode 100644 index 0000000000..f581fdebbd --- /dev/null +++ b/google/cloud/spanner_dbapi/_helpers.py @@ -0,0 +1,159 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +from google.cloud.spanner_dbapi.parse_utils import get_param_types +from google.cloud.spanner_dbapi.parse_utils import parse_insert +from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner +from google.cloud.spanner_v1 import param_types + + +SQL_LIST_TABLES = """ + SELECT + t.table_name + FROM + information_schema.tables AS t + WHERE + t.table_catalog = '' and t.table_schema = '' + """ + +SQL_GET_TABLE_COLUMN_SCHEMA = """SELECT + COLUMN_NAME, IS_NULLABLE, SPANNER_TYPE + FROM + INFORMATION_SCHEMA.COLUMNS + WHERE + TABLE_SCHEMA = '' + AND + TABLE_NAME = @table_name + """ + +# This table maps spanner_types to Spanner's data type sizes as per +# https://cloud.google.com/spanner/docs/data-types#allowable-types +# It is used to map `display_size` to a known type for Cursor.description +# after a row fetch. +# Since ResultMetadata +# https://cloud.google.com/spanner/docs/reference/rest/v1/ResultSetMetadata +# does not send back the actual size, we have to lookup the respective size. +# Some fields' sizes are dependent upon the dynamic data hence aren't sent back +# by Cloud Spanner. +code_to_display_size = { + param_types.BOOL.code: 1, + param_types.DATE.code: 4, + param_types.FLOAT64.code: 8, + param_types.INT64.code: 8, + param_types.TIMESTAMP.code: 12, +} + + +def _execute_insert_heterogenous(transaction, sql_params_list): + for sql, params in sql_params_list: + sql, params = sql_pyformat_args_to_spanner(sql, params) + param_types = get_param_types(params) + res = transaction.execute_sql( + sql, params=params, param_types=param_types + ) + # TODO: File a bug with Cloud Spanner and the Python client maintainers + # about a lost commit when res isn't read from. + _ = list(res) + + +def _execute_insert_homogenous(transaction, parts): + # Perform an insert in one shot. + table = parts.get("table") + columns = parts.get("columns") + values = parts.get("values") + return transaction.insert(table, columns, values) + + +def handle_insert(connection, sql, params): + parts = parse_insert(sql, params) + + # The split between the two styles exists because: + # in the common case of multiple values being passed + # with simple pyformat arguments, + # SQL: INSERT INTO T (f1, f2) VALUES (%s, %s, %s) + # Params: [(1, 2, 3, 4, 5, 6, 7, 8, 9, 10,)] + # we can take advantage of a single RPC with: + # transaction.insert(table, columns, values) + # instead of invoking: + # with transaction: + # for sql, params in sql_params_list: + # transaction.execute_sql(sql, params, param_types) + # which invokes more RPCs and is more costly. + + if parts.get("homogenous"): + # The common case of multiple values being passed in + # non-complex pyformat args and need to be uploaded in one RPC. + return connection.database.run_in_transaction( + _execute_insert_homogenous, parts + ) + else: + # All the other cases that are esoteric and need + # transaction.execute_sql + sql_params_list = parts.get("sql_params_list") + return connection.database.run_in_transaction( + _execute_insert_heterogenous, sql_params_list + ) + + +class ColumnInfo: + """Row column description object.""" + + def __init__( + self, + name, + type_code, + display_size=None, + internal_size=None, + precision=None, + scale=None, + null_ok=False, + ): + self.name = name + self.type_code = type_code + self.display_size = display_size + self.internal_size = internal_size + self.precision = precision + self.scale = scale + self.null_ok = null_ok + + self.fields = ( + self.name, + self.type_code, + self.display_size, + self.internal_size, + self.precision, + self.scale, + self.null_ok, + ) + + def __repr__(self): + return self.__str__() + + def __getitem__(self, index): + return self.fields[index] + + def __str__(self): + str_repr = ", ".join( + filter( + lambda part: part is not None, + [ + "name='%s'" % self.name, + "type_code=%d" % self.type_code, + "display_size=%d" % self.display_size + if self.display_size + else None, + "internal_size=%d" % self.internal_size + if self.internal_size + else None, + "precision='%s'" % self.precision + if self.precision + else None, + "scale='%s'" % self.scale if self.scale else None, + "null_ok='%s'" % self.null_ok if self.null_ok else None, + ], + ) + ) + return "ColumnInfo(%s)" % str_repr diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py new file mode 100644 index 0000000000..b572c8573b --- /dev/null +++ b/google/cloud/spanner_dbapi/connection.py @@ -0,0 +1,261 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +"""DB-API Connection for the Google Cloud Spanner.""" + +import warnings + +from google.api_core.gapic_v1.client_info import ClientInfo +from google.cloud import spanner_v1 as spanner + +from google.cloud.spanner_dbapi.cursor import Cursor +from google.cloud.spanner_dbapi.exceptions import InterfaceError +from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT +from google.cloud.spanner_dbapi.version import PY_VERSION + + +AUTOCOMMIT_MODE_WARNING = "This method is non-operational in autocommit mode" + + +class Connection: + """Representation of a DB-API connection to a Cloud Spanner database. + + You most likely don't need to instantiate `Connection` objects + directly, use the `connect` module function instead. + + :type instance: :class:`~google.cloud.spanner_v1.instance.Instance` + :param instance: Cloud Spanner instance to connect to. + + :type database: :class:`~google.cloud.spanner_v1.database.Database` + :param database: The database to which the connection is linked. + """ + + def __init__(self, instance, database): + self._instance = instance + self._database = database + self._ddl_statements = [] + + self._transaction = None + self._session = None + + self.is_closed = False + self._autocommit = False + + @property + def autocommit(self): + """Autocommit mode flag for this connection. + + :rtype: bool + :returns: Autocommit mode flag value. + """ + return self._autocommit + + @autocommit.setter + def autocommit(self, value): + """Change this connection autocommit mode. Setting this value to True + while a transaction is active will commit the current transaction. + + :type value: bool + :param value: New autocommit mode state. + """ + if value and not self._autocommit: + self.commit() + + self._autocommit = value + + @property + def database(self): + """Database to which this connection relates. + + :rtype: :class:`~google.cloud.spanner_v1.database.Database` + :returns: The related database object. + """ + return self._database + + @property + def instance(self): + """Instance to which this connection relates. + + :rtype: :class:`~google.cloud.spanner_v1.instance.Instance` + :returns: The related instance object. + """ + return self._instance + + def _session_checkout(self): + """Get a Cloud Spanner session from the pool. + + If there is already a session associated with + this connection, it'll be used instead. + + :rtype: :class:`google.cloud.spanner_v1.session.Session` + :returns: Cloud Spanner session object ready to use. + """ + if not self._session: + self._session = self.database._pool.get() + + return self._session + + def _release_session(self): + """Release the currently used Spanner session. + + The session will be returned into the sessions pool. + """ + self.database._pool.put(self._session) + self._session = None + + def transaction_checkout(self): + """Get a Cloud Spanner transaction. + + Begin a new transaction, if there is no transaction in + this connection yet. Return the begun one otherwise. + + The method is non operational in autocommit mode. + + :rtype: :class:`google.cloud.spanner_v1.transaction.Transaction` + :returns: A Cloud Spanner transaction object, ready to use. + """ + if not self.autocommit: + if ( + not self._transaction + or self._transaction.committed + or self._transaction.rolled_back + ): + self._transaction = self._session_checkout().transaction() + self._transaction.begin() + + return self._transaction + + def _raise_if_closed(self): + """Helper to check the connection state before running a query. + Raises an exception if this connection is closed. + + :raises: :class:`InterfaceError`: if this connection is closed. + """ + if self.is_closed: + raise InterfaceError("connection is already closed") + + def close(self): + """Closes this connection. + + The connection will be unusable from this point forward. If the + connection has an active transaction, it will be rolled back. + """ + if ( + self._transaction + and not self._transaction.committed + and not self._transaction.rolled_back + ): + self._transaction.rollback() + + self.is_closed = True + + def commit(self): + """Commits any pending transaction to the database. + + This method is non-operational in autocommit mode. + """ + if self._autocommit: + warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2) + elif self._transaction: + self._transaction.commit() + self._release_session() + + def rollback(self): + """Rolls back any pending transaction. + + This is a no-op if there is no active transaction or if the connection + is in autocommit mode. + """ + if self._autocommit: + warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2) + elif self._transaction: + self._transaction.rollback() + self._release_session() + + def cursor(self): + """Factory to create a DB-API Cursor.""" + self._raise_if_closed() + + return Cursor(self) + + def run_prior_DDL_statements(self): + self._raise_if_closed() + + if self._ddl_statements: + ddl_statements = self._ddl_statements + self._ddl_statements = [] + + return self.database.update_ddl(ddl_statements).result() + + def __enter__(self): + return self + + def __exit__(self, etype, value, traceback): + self.commit() + self.close() + + +def connect( + instance_id, + database_id, + project=None, + credentials=None, + pool=None, + user_agent=None, +): + """Creates a connection to a Google Cloud Spanner database. + + :type instance_id: str + :param instance_id: The ID of the instance to connect to. + + :type database_id: str + :param database_id: The ID of the database to connect to. + + :type project: str + :param project: (Optional) The ID of the project which owns the + instances, tables and data. If not provided, will + attempt to determine from the environment. + + :type credentials: :class:`~google.auth.credentials.Credentials` + :param credentials: (Optional) The authorization credentials to attach to + requests. These credentials identify this application + to the service. If none are specified, the client will + attempt to ascertain the credentials from the + environment. + + :type pool: Concrete subclass of + :class:`~google.cloud.spanner_v1.pool.AbstractSessionPool`. + :param pool: (Optional). Session pool to be used by database. + + :type user_agent: str + :param user_agent: (Optional) User agent to be used with this connection's + requests. + + :rtype: :class:`google.cloud.spanner_dbapi.connection.Connection` + :returns: Connection object associated with the given Google Cloud Spanner + resource. + + :raises: :class:`ValueError` in case of given instance/database + doesn't exist. + """ + + client_info = ClientInfo( + user_agent=user_agent or DEFAULT_USER_AGENT, python_version=PY_VERSION, + ) + + client = spanner.Client( + project=project, credentials=credentials, client_info=client_info, + ) + + instance = client.instance(instance_id) + if not instance.exists(): + raise ValueError("instance '%s' does not exist." % instance_id) + + database = instance.database(database_id, pool=pool) + if not database.exists(): + raise ValueError("database '%s' does not exist." % database_id) + + return Connection(instance, database) diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py new file mode 100644 index 0000000000..6997752a42 --- /dev/null +++ b/google/cloud/spanner_dbapi/cursor.py @@ -0,0 +1,329 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +"""Database cursor for Google Cloud Spanner DB-API.""" + +from google.api_core.exceptions import AlreadyExists +from google.api_core.exceptions import FailedPrecondition +from google.api_core.exceptions import InternalServerError +from google.api_core.exceptions import InvalidArgument + +from collections import namedtuple + +from google.cloud import spanner_v1 as spanner + +from google.cloud.spanner_dbapi.exceptions import IntegrityError +from google.cloud.spanner_dbapi.exceptions import InterfaceError +from google.cloud.spanner_dbapi.exceptions import OperationalError +from google.cloud.spanner_dbapi.exceptions import ProgrammingError + +from google.cloud.spanner_dbapi import _helpers +from google.cloud.spanner_dbapi._helpers import ColumnInfo +from google.cloud.spanner_dbapi._helpers import code_to_display_size + +from google.cloud.spanner_dbapi import parse_utils +from google.cloud.spanner_dbapi.parse_utils import get_param_types +from google.cloud.spanner_dbapi.utils import PeekIterator + +_UNSET_COUNT = -1 + +ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"]) + + +class Cursor(object): + """Database cursor to manage the context of a fetch operation. + + :type connection: :class:`~google.cloud.spanner_dbapi.connection.Connection` + :param connection: A DB-API connection to Google Cloud Spanner. + """ + + def __init__(self, connection): + self._itr = None + self._result_set = None + self._row_count = _UNSET_COUNT + self.connection = connection + self._is_closed = False + + # the number of rows to fetch at a time with fetchmany() + self.arraysize = 1 + + @property + def is_closed(self): + """The cursor close indicator. + + :rtype: bool + :returns: True if the cursor or the parent connection is closed, + otherwise False. + """ + return self._is_closed or self.connection.is_closed + + @property + def description(self): + """Read-only attribute containing a sequence of the following items: + + - ``name`` + - ``type_code`` + - ``display_size`` + - ``internal_size`` + - ``precision`` + - ``scale`` + - ``null_ok`` + """ + if not (self._result_set and self._result_set.metadata): + return None + + row_type = self._result_set.metadata.row_type + columns = [] + + for field in row_type.fields: + column_info = ColumnInfo( + name=field.name, + type_code=field.type.code, + # Size of the SQL type of the column. + display_size=code_to_display_size.get(field.type.code), + # Client perceived size of the column. + internal_size=field.ByteSize(), + ) + columns.append(column_info) + + return tuple(columns) + + @property + def rowcount(self): + """The number of rows produced by the last `.execute()`.""" + return self._row_count + + def _raise_if_closed(self): + """Raise an exception if this cursor is closed. + + Helper to check this cursor's state before running a + SQL/DDL/DML query. If the parent connection is + already closed it also raises an error. + + :raises: :class:`InterfaceError` if this cursor is closed. + """ + if self.is_closed: + raise InterfaceError("Cursor and/or connection is already closed.") + + def callproc(self, procname, args=None): + """A no-op, raising an error if the cursor or connection is closed.""" + self._raise_if_closed() + + def close(self): + """Closes this Cursor, making it unusable from this point forward.""" + self._is_closed = True + + def _do_execute_update(self, transaction, sql, params, param_types=None): + sql = parse_utils.ensure_where_clause(sql) + sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params) + + result = transaction.execute_update( + sql, params=params, param_types=get_param_types(params) + ) + self._itr = None + if type(result) == int: + self._row_count = result + + return result + + def execute(self, sql, args=None): + """Prepares and executes a Spanner database operation. + + :type sql: str + :param sql: A SQL query statement. + + :type args: list + :param args: Additional parameters to supplement the SQL query. + """ + if not self.connection: + raise ProgrammingError("Cursor is not connected to the database") + + self._raise_if_closed() + + self._result_set = None + + # Classify whether this is a read-only SQL statement. + try: + classification = parse_utils.classify_stmt(sql) + if classification == parse_utils.STMT_DDL: + self.connection._ddl_statements.append(sql) + return + + # For every other operation, we've got to ensure that + # any prior DDL statements were run. + # self._run_prior_DDL_statements() + self.connection.run_prior_DDL_statements() + + if not self.connection.autocommit: + transaction = self.connection.transaction_checkout() + + sql, params = parse_utils.sql_pyformat_args_to_spanner( + sql, args + ) + + self._result_set = transaction.execute_sql( + sql, params, param_types=get_param_types(params) + ) + self._itr = PeekIterator(self._result_set) + return + + if classification == parse_utils.STMT_NON_UPDATING: + self._handle_DQL(sql, args or None) + elif classification == parse_utils.STMT_INSERT: + _helpers.handle_insert(self.connection, sql, args or None) + else: + self.connection.database.run_in_transaction( + self._do_execute_update, sql, args or None + ) + except (AlreadyExists, FailedPrecondition) as e: + raise IntegrityError(e.details if hasattr(e, "details") else e) + except InvalidArgument as e: + raise ProgrammingError(e.details if hasattr(e, "details") else e) + except InternalServerError as e: + raise OperationalError(e.details if hasattr(e, "details") else e) + + def executemany(self, operation, seq_of_params): + """Execute the given SQL with every parameters set + from the given sequence of parameters. + + :type operation: str + :param operation: SQL code to execute. + + :type seq_of_params: list + :param seq_of_params: Sequence of additional parameters to run + the query with. + """ + self._raise_if_closed() + + for params in seq_of_params: + self.execute(operation, params) + + def fetchone(self): + """Fetch the next row of a query result set, returning a single + sequence, or None when no more data is available.""" + self._raise_if_closed() + + try: + return next(self) + except StopIteration: + return None + + def fetchmany(self, size=None): + """Fetch the next set of rows of a query result, returning a sequence + of sequences. An empty sequence is returned when no more rows are available. + + :type size: int + :param size: (Optional) The maximum number of results to fetch. + + :raises InterfaceError: + if the previous call to .execute*() did not produce any result set + or if no call was issued yet. + """ + self._raise_if_closed() + + if size is None: + size = self.arraysize + + items = [] + for i in range(size): + try: + items.append(tuple(self.__next__())) + except StopIteration: + break + + return items + + def fetchall(self): + """Fetch all (remaining) rows of a query result, returning them as + a sequence of sequences. + """ + self._raise_if_closed() + + return list(self.__iter__()) + + def nextset(self): + """A no-op, raising an error if the cursor or connection is closed.""" + self._raise_if_closed() + + def setinputsizes(self, sizes): + """A no-op, raising an error if the cursor or connection is closed.""" + self._raise_if_closed() + + def setoutputsize(self, size, column=None): + """A no-op, raising an error if the cursor or connection is closed.""" + self._raise_if_closed() + + def _handle_DQL(self, sql, params): + with self.connection.database.snapshot() as snapshot: + # Reference + # https://googleapis.dev/python/spanner/latest/session-api.html#google.cloud.spanner_v1.session.Session.execute_sql + sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params) + res = snapshot.execute_sql( + sql, params=params, param_types=get_param_types(params) + ) + if type(res) == int: + self._row_count = res + self._itr = None + else: + # Immediately using: + # iter(response) + # here, because this Spanner API doesn't provide + # easy mechanisms to detect when only a single item + # is returned or many, yet mixing results that + # are for .fetchone() with those that would result in + # many items returns a RuntimeError if .fetchone() is + # invoked and vice versa. + self._result_set = res + # Read the first element so that the StreamedResultSet can + # return the metadata after a DQL statement. See issue #155. + self._itr = PeekIterator(self._result_set) + # Unfortunately, Spanner doesn't seem to send back + # information about the number of rows available. + self._row_count = _UNSET_COUNT + + def __enter__(self): + return self + + def __exit__(self, etype, value, traceback): + self.close() + + def __next__(self): + if self._itr is None: + raise ProgrammingError("no results to return") + return next(self._itr) + + def __iter__(self): + if self._itr is None: + raise ProgrammingError("no results to return") + return self._itr + + def list_tables(self): + return self.run_sql_in_snapshot(_helpers.SQL_LIST_TABLES) + + def run_sql_in_snapshot(self, sql, params=None, param_types=None): + # Some SQL e.g. for INFORMATION_SCHEMA cannot be run in read-write transactions + # hence this method exists to circumvent that limit. + self.connection.run_prior_DDL_statements() + + with self.connection.database.snapshot() as snapshot: + res = snapshot.execute_sql( + sql, params=params, param_types=param_types + ) + return list(res) + + def get_table_column_schema(self, table_name): + rows = self.run_sql_in_snapshot( + sql=_helpers.SQL_GET_TABLE_COLUMN_SCHEMA, + params={"table_name": table_name}, + param_types={"table_name": spanner.param_types.STRING}, + ) + + column_details = {} + for column_name, is_nullable, spanner_type in rows: + column_details[column_name] = ColumnDetails( + null_ok=is_nullable == "YES", spanner_type=spanner_type + ) + return column_details diff --git a/google/cloud/spanner_dbapi/exceptions.py b/google/cloud/spanner_dbapi/exceptions.py new file mode 100644 index 0000000000..b21be2c949 --- /dev/null +++ b/google/cloud/spanner_dbapi/exceptions.py @@ -0,0 +1,94 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +"""Spanner DB API exceptions.""" + + +class Warning(Exception): + """Important DB API warning.""" + + pass + + +class Error(Exception): + """The base class for all the DB API exceptions. + + Does not include :class:`Warning`. + """ + + pass + + +class InterfaceError(Error): + """ + Error related to the database interface + rather than the database itself. + """ + + pass + + +class DatabaseError(Error): + """Error related to the database.""" + + pass + + +class DataError(DatabaseError): + """ + Error due to problems with the processed data like + division by zero, numeric value out of range, etc. + """ + + pass + + +class OperationalError(DatabaseError): + """ + Error related to the database's operation, e.g. an + unexpected disconnect, the data source name is not + found, a transaction could not be processed, a + memory allocation error, etc. + """ + + pass + + +class IntegrityError(DatabaseError): + """ + Error for cases of relational integrity of the database + is affected, e.g. a foreign key check fails. + """ + + pass + + +class InternalError(DatabaseError): + """ + Internal database error, e.g. the cursor is not valid + anymore, the transaction is out of sync, etc. + """ + + pass + + +class ProgrammingError(DatabaseError): + """ + Programming error, e.g. table not found or already + exists, syntax error in the SQL statement, wrong + number of parameters specified, etc. + """ + + pass + + +class NotSupportedError(DatabaseError): + """ + Error for case of a method or database API not + supported by the database was used. + """ + + pass diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py new file mode 100644 index 0000000000..084eea315e --- /dev/null +++ b/google/cloud/spanner_dbapi/parse_utils.py @@ -0,0 +1,542 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +"SQL parsing and classification utils." + +import datetime +import decimal +import re +from functools import reduce + +import sqlparse +from google.cloud import spanner_v1 as spanner + +from .exceptions import Error, ProgrammingError +from .parser import parse_values +from .types import DateStr, TimestampStr +from .utils import sanitize_literals_for_upload + +TYPES_MAP = { + bool: spanner.param_types.BOOL, + bytes: spanner.param_types.BYTES, + str: spanner.param_types.STRING, + int: spanner.param_types.INT64, + float: spanner.param_types.FLOAT64, + datetime.datetime: spanner.param_types.TIMESTAMP, + datetime.date: spanner.param_types.DATE, + DateStr: spanner.param_types.DATE, + TimestampStr: spanner.param_types.TIMESTAMP, +} + +SPANNER_RESERVED_KEYWORDS = { + "ALL", + "AND", + "ANY", + "ARRAY", + "AS", + "ASC", + "ASSERT_ROWS_MODIFIED", + "AT", + "BETWEEN", + "BY", + "CASE", + "CAST", + "COLLATE", + "CONTAINS", + "CREATE", + "CROSS", + "CUBE", + "CURRENT", + "DEFAULT", + "DEFINE", + "DESC", + "DISTINCT", + "DROP", + "ELSE", + "END", + "ENUM", + "ESCAPE", + "EXCEPT", + "EXCLUDE", + "EXISTS", + "EXTRACT", + "FALSE", + "FETCH", + "FOLLOWING", + "FOR", + "FROM", + "FULL", + "GROUP", + "GROUPING", + "GROUPS", + "HASH", + "HAVING", + "IF", + "IGNORE", + "IN", + "INNER", + "INTERSECT", + "INTERVAL", + "INTO", + "IS", + "JOIN", + "LATERAL", + "LEFT", + "LIKE", + "LIMIT", + "LOOKUP", + "MERGE", + "NATURAL", + "NEW", + "NO", + "NOT", + "NULL", + "NULLS", + "OF", + "ON", + "OR", + "ORDER", + "OUTER", + "OVER", + "PARTITION", + "PRECEDING", + "PROTO", + "RANGE", + "RECURSIVE", + "RESPECT", + "RIGHT", + "ROLLUP", + "ROWS", + "SELECT", + "SET", + "SOME", + "STRUCT", + "TABLESAMPLE", + "THEN", + "TO", + "TREAT", + "TRUE", + "UNBOUNDED", + "UNION", + "UNNEST", + "USING", + "WHEN", + "WHERE", + "WINDOW", + "WITH", + "WITHIN", +} + +STMT_DDL = "DDL" +STMT_NON_UPDATING = "NON_UPDATING" +STMT_UPDATING = "UPDATING" +STMT_INSERT = "INSERT" + +# Heuristic for identifying statements that don't need to be run as updates. +RE_NON_UPDATE = re.compile(r"^\s*(SELECT)", re.IGNORECASE) + +RE_WITH = re.compile(r"^\s*(WITH)", re.IGNORECASE) + +# DDL statements follow +# https://cloud.google.com/spanner/docs/data-definition-language +RE_DDL = re.compile(r"^\s*(CREATE|ALTER|DROP)", re.IGNORECASE | re.DOTALL) + +RE_IS_INSERT = re.compile(r"^\s*(INSERT)", re.IGNORECASE | re.DOTALL) + +RE_INSERT = re.compile( + # Only match the `INSERT INTO (columns...) + # otherwise the rest of the statement could be a complex + # operation. + r"^\s*INSERT INTO (?P[^\s\(\)]+)\s*\((?P[^\(\)]+)\)", + re.IGNORECASE | re.DOTALL, +) + +RE_VALUES_TILL_END = re.compile(r"VALUES\s*\(.+$", re.IGNORECASE | re.DOTALL) + +RE_VALUES_PYFORMAT = re.compile( + # To match: (%s, %s,....%s) + r"(\(\s*%s[^\(\)]+\))", + re.DOTALL, +) + +RE_PYFORMAT = re.compile(r"(%s|%\([^\(\)]+\)s)+", re.DOTALL) + + +def classify_stmt(query): + """Determine SQL query type. + + :type query: :class:`str` + :param query: SQL query. + + :rtype: :class:`str` + :returns: Query type name. + """ + if RE_DDL.match(query): + return STMT_DDL + + if RE_IS_INSERT.match(query): + return STMT_INSERT + + if RE_NON_UPDATE.match(query) or RE_WITH.match(query): + # As of 13-March-2020, Cloud Spanner only supports WITH for DQL + # statements and doesn't yet support WITH for DML statements. + return STMT_NON_UPDATING + + return STMT_UPDATING + + +def parse_insert(insert_sql, params): + """ + Parse an INSERT statement an generate a list of tuples of the form: + [ + (SQL, params_per_row1), + (SQL, params_per_row2), + (SQL, params_per_row3), + ... + ] + + There are 4 variants of an INSERT statement: + a) INSERT INTO (columns...) VALUES (): no params + b) INSERT INTO
(columns...) SELECT_STMT: no params + c) INSERT INTO
(columns...) VALUES (%s,...): with params + d) INSERT INTO
(columns...) VALUES (%s,.....) with params and expressions + + Thus given each of the forms, it will produce a dictionary describing + how to upload the contents to Cloud Spanner: + Case a) + SQL: INSERT INTO T (f1, f2) VALUES (1, 2) + it produces: + { + 'sql_params_list': [ + ('INSERT INTO T (f1, f2) VALUES (1, 2)', None), + ], + } + + Case b) + SQL: 'INSERT INTO T (s, c) SELECT st, zc FROM cus ORDER BY fn, ln', + it produces: + { + 'sql_params_list': [ + ('INSERT INTO T (s, c) SELECT st, zc FROM cus ORDER BY fn, ln', None), + ] + } + + Case c) + SQL: INSERT INTO T (f1, f2) VALUES (%s, %s), (%s, %s) + Params: ['a', 'b', 'c', 'd'] + it produces: + { + 'homogenous': True, + 'table': 'T', + 'columns': ['f1', 'f2'], + 'values': [('a', 'b',), ('c', 'd',)], + } + + Case d) + SQL: INSERT INTO T (f1, f2) VALUES (%s, LOWER(%s)), (UPPER(%s), %s) + Params: ['a', 'b', 'c', 'd'] + it produces: + { + 'sql_params_list': [ + ('INSERT INTO T (f1, f2) VALUES (%s, LOWER(%s))', ('a', 'b',)) + ('INSERT INTO T (f1, f2) VALUES (UPPER(%s), %s)', ('c', 'd',)) + ], + } + """ # noqa + match = RE_INSERT.search(insert_sql) + + if not match: + raise ProgrammingError( + "Could not parse an INSERT statement from %s" % insert_sql + ) + + after_values_sql = RE_VALUES_TILL_END.findall(insert_sql) + if not after_values_sql: + # Case b) + insert_sql = sanitize_literals_for_upload(insert_sql) + return {"sql_params_list": [(insert_sql, None)]} + + if not params: + # Case a) perhaps? + # Check if any %s exists. + + # pyformat_str_count = after_values_sql.count("%s") + # if pyformat_str_count > 0: + # raise ProgrammingError( + # 'no params yet there are %d "%%s" tokens' % pyformat_str_count + # ) + for item in after_values_sql: + if item.count("%s") > 0: + raise ProgrammingError( + 'no params yet there are %d "%%s" tokens' + % item.count("%s") + ) + + insert_sql = sanitize_literals_for_upload(insert_sql) + # Confirmed case of: + # SQL: INSERT INTO T (a1, a2) VALUES (1, 2) + # Params: None + return {"sql_params_list": [(insert_sql, None)]} + + values_str = after_values_sql[0] + _, values = parse_values(values_str) + + if values.homogenous(): + # Case c) + + columns = [mi.strip(" `") for mi in match.group("columns").split(",")] + sql_params_list = [] + insert_sql_preamble = "INSERT INTO %s (%s) VALUES %s" % ( + match.group("table_name"), + match.group("columns"), + values.argv[0], + ) + values_pyformat = [str(arg) for arg in values.argv] + rows_list = rows_for_insert_or_update(columns, params, values_pyformat) + insert_sql_preamble = sanitize_literals_for_upload(insert_sql_preamble) + for row in rows_list: + sql_params_list.append((insert_sql_preamble, row)) + + return {"sql_params_list": sql_params_list} + + # Case d) + # insert_sql is of the form: + # INSERT INTO T(c1, c2) VALUES (%s, %s), (%s, LOWER(%s)) + + # Sanity check: + # length(all_args) == len(params) + args_len = reduce(lambda a, b: a + b, [len(arg) for arg in values.argv]) + if args_len != len(params): + raise ProgrammingError( + "Invalid length: VALUES(...) len: %d != len(params): %d" + % (args_len, len(params)) + ) + + trim_index = insert_sql.find(values_str) + before_values_sql = insert_sql[:trim_index] + + sql_param_tuples = [] + for token_arg in values.argv: + row_sql = before_values_sql + " VALUES%s" % token_arg + row_sql = sanitize_literals_for_upload(row_sql) + row_params, params = ( + tuple(params[0 : len(token_arg)]), + params[len(token_arg) :], + ) + sql_param_tuples.append((row_sql, row_params)) + + return {"sql_params_list": sql_param_tuples} + + +def rows_for_insert_or_update(columns, params, pyformat_args=None): + """ + Create a tupled list of params to be used as a single value per + value that inserted from a statement such as + SQL: 'INSERT INTO t (f1, f2, f3) VALUES (%s, %s, %s), (%s, %s, %s), (%s, %s, %s)' + Params A: [(1, 2, 3), (4, 5, 6), (7, 8, 9)] + Params B: [1, 2, 3, 4, 5, 6, 7, 8, 9] + + We'll have to convert both params types into: + Params: [(1, 2, 3,), (4, 5, 6,), (7, 8, 9,)] + """ # noqa + + if not pyformat_args: + # This is the case where we have for example: + # SQL: 'INSERT INTO t (f1, f2, f3)' + # Params A: [(1, 2, 3), (4, 5, 6), (7, 8, 9)] + # Params B: [1, 2, 3, 4, 5, 6, 7, 8, 9] + # + # We'll have to convert both params types into: + # [(1, 2, 3,), (4, 5, 6,), (7, 8, 9,)] + contains_all_list_or_tuples = True + for param in params: + if not (isinstance(param, list) or isinstance(param, tuple)): + contains_all_list_or_tuples = False + break + + if contains_all_list_or_tuples: + # The case with Params A: [(1, 2, 3), (4, 5, 6)] + # Ensure that each param's length == len(columns) + columns_len = len(columns) + for param in params: + if columns_len != len(param): + raise Error( + "\nlen(`%s`)=%d\n!=\ncolum_len(`%s`)=%d" + % (param, len(param), columns, columns_len) + ) + return params + else: + # The case with Params B: [1, 2, 3] + # Insert statements' params are only passed as tuples or lists, + # yet for do_execute_update, we've got to pass in list of list. + # https://googleapis.dev/python/spanner/latest/transaction-api.html\ + # #google.cloud.spanner_v1.transaction.Transaction.insert + n_stride = len(columns) + else: + # This is the case where we have for example: + # SQL: 'INSERT INTO t (f1, f2, f3) VALUES (%s, %s, %s), + # (%s, %s, %s), (%s, %s, %s)' + # Params: [1, 2, 3, 4, 5, 6, 7, 8, 9] + # which should become + # Columns: (f1, f2, f3) + # new_params: [(1, 2, 3,), (4, 5, 6,), (7, 8, 9,)] + + # Sanity check 1: all the pyformat_values should have the exact same + # length. + first, rest = pyformat_args[0], pyformat_args[1:] + n_stride = first.count("%s") + for pyfmt_value in rest: + n = pyfmt_value.count("%s") + if n_stride != n: + raise Error( + "\nlen(`%s`)=%d\n!=\nlen(`%s`)=%d" + % (first, n_stride, pyfmt_value, n) + ) + + # Sanity check 2: len(params) MUST be a multiple of n_stride aka + # len(count of %s). + # so that we can properly group for example: + # Given pyformat args: + # (%s, %s, %s) + # Params: + # [1, 2, 3, 4, 5, 6, 7, 8, 9] + # into + # [(1, 2, 3), (4, 5, 6), (7, 8, 9)] + if (len(params) % n_stride) != 0: + raise ProgrammingError( + "Invalid length: len(params)=%d MUST be a multiple of " + "len(pyformat_args)=%d" % (len(params), n_stride) + ) + + # Now chop up the strides. + strides = [] + for step in range(0, len(params), n_stride): + stride = tuple(params[step : step + n_stride :]) + strides.append(stride) + + return strides + + +def sql_pyformat_args_to_spanner(sql, params): + """ + Transform pyformat set SQL to named arguments for Cloud Spanner. + It will also unescape previously escaped format specifiers + like %%s to %s. + For example: + SQL: 'SELECT * from t where f1=%s, f2=%s, f3=%s' + Params: ('a', 23, '888***') + becomes: + SQL: 'SELECT * from t where f1=@a0, f2=@a1, f3=@a2' + Params: {'a0': 'a', 'a1': 23, 'a2': '888***'} + + OR + SQL: 'SELECT * from t where f1=%(f1)s, f2=%(f2)s, f3=%(f3)s' + Params: {'f1': 'a', 'f2': 23, 'f3': '888***', 'extra': 'aye') + becomes: + SQL: 'SELECT * from t where f1=@a0, f2=@a1, f3=@a2' + Params: {'a0': 'a', 'a1': 23, 'a2': '888***'} + """ + if not params: + return sanitize_literals_for_upload(sql), params + + found_pyformat_placeholders = RE_PYFORMAT.findall(sql) + params_is_dict = isinstance(params, dict) + + if params_is_dict: + if not found_pyformat_placeholders: + return sanitize_literals_for_upload(sql), params + else: + n_params = len(params) if params else 0 + n_matches = len(found_pyformat_placeholders) + if n_matches != n_params: + raise Error( + "pyformat_args mismatch\ngot %d args from %s\n" + "want %d args in %s" + % (n_matches, found_pyformat_placeholders, n_params, params) + ) + + named_args = {} + # We've now got for example: + # Case a) Params is a non-dict + # SQL: 'SELECT * from t where f1=%s, f2=%s, f3=%s' + # Params: ('a', 23, '888***') + # Case b) Params is a dict and the matches are %(value)s' + for i, pyfmt in enumerate(found_pyformat_placeholders): + key = "a%d" % i + sql = sql.replace(pyfmt, "@" + key, 1) + if params_is_dict: + # The '%(key)s' case, so interpolate it. + resolved_value = pyfmt % params + named_args[key] = resolved_value + else: + named_args[key] = cast_for_spanner(params[i]) + + return sanitize_literals_for_upload(sql), named_args + + +def cast_for_spanner(value): + """Convert the param to its Cloud Spanner equivalent type. + + :type value: Any + :param value: Value to convert to a Cloud Spanner type. + + :rtype: Any + :returns: Value converted to a Cloud Spanner type. + """ + if isinstance(value, decimal.Decimal): + return float(value) + return value + + +def get_param_types(params): + """Determine Cloud Spanner types for the given parameters. + + :type params: :class:`dict` + :param params: Parameters requiring to find Cloud Spanner types. + + :rtype: :class:`dict` + :returns: The types index for the given parameters. + """ + if params is None: + return + + param_types = {} + + for key, value in params.items(): + type_ = type(value) + if type_ in TYPES_MAP: + param_types[key] = TYPES_MAP[type_] + + return param_types + + +def ensure_where_clause(sql): + """ + Cloud Spanner requires a WHERE clause on UPDATE and DELETE statements. + Add a dummy WHERE clause if necessary. + """ + if any( + isinstance(token, sqlparse.sql.Where) + for token in sqlparse.parse(sql)[0] + ): + return sql + return sql + " WHERE 1=1" + + +def escape_name(name): + """ + Apply backticks to the name that either contain '-' or + ' ', or is a Cloud Spanner's reserved keyword. + + :type name: :class:`str` + :param name: Name to escape. + + :rtype: :class:`str` + :returns: Name escaped if it has to be escaped. + """ + if "-" in name or " " in name or name.upper() in SPANNER_RESERVED_KEYWORDS: + return "`" + name + "`" + return name diff --git a/google/cloud/spanner_dbapi/parser.py b/google/cloud/spanner_dbapi/parser.py new file mode 100644 index 0000000000..2fc0156b57 --- /dev/null +++ b/google/cloud/spanner_dbapi/parser.py @@ -0,0 +1,246 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +""" +Grammar for parsing VALUES: + VALUES := `VALUES(` + ARGS + `)` + ARGS := [EXPR,]*EXPR + EXPR := TERMINAL / FUNC + TERMINAL := `%s` + FUNC := alphanum + `(` + ARGS + `)` + alphanum := (a-zA-Z_)[0-9a-ZA-Z_]* + +thus given: + statement: 'VALUES (%s, %s), (%s, LOWER(UPPER(%s))) , (%s)' + It'll parse: + VALUES + |- ARGS + |- (TERMINAL, TERMINAL) + |- (TERMINAL, FUNC + |- FUNC + |- (TERMINAL) + |- (TERMINAL) +""" + +from .exceptions import ProgrammingError + +ARGS = "ARGS" +EXPR = "EXPR" +FUNC = "FUNC" +VALUES = "VALUES" + + +class func(object): + def __init__(self, func_name, args): + self.name = func_name + self.args = args + + def __str__(self): + return "%s%s" % (self.name, self.args) + + def __repr__(self): + return self.__str__() + + def __eq__(self, other): + if type(self) != type(other): + return False + if self.name != other.name: + return False + if not isinstance(other.args, type(self.args)): + return False + if len(self.args) != len(other.args): + return False + return self.args == other.args + + def __len__(self): + return len(self.args) + + +class terminal(str): + """ + terminal represents the unit symbol that can be part of a SQL values clause. + """ + + pass + + +class a_args(object): + def __init__(self, argv): + self.argv = argv + + def __str__(self): + return "(" + ", ".join([str(arg) for arg in self.argv]) + ")" + + def __repr__(self): + return self.__str__() + + def has_expr(self): + return any( + [token for token in self.argv if not isinstance(token, terminal)] + ) + + def __len__(self): + return len(self.argv) + + def __eq__(self, other): + if type(self) != type(other): + return False + + if len(self) != len(other): + return False + + for i, item in enumerate(self): + if item != other[i]: + return False + + return True + + def __getitem__(self, index): + return self.argv[index] + + def homogenous(self): + """ + Return True if all the arguments are pyformat + args and have the same number of arguments. + """ + if not self._is_equal_length(): + return False + + for arg in self.argv: + if isinstance(arg, terminal): + continue + elif isinstance(arg, a_args): + if not arg.homogenous(): + return False + else: + return False + return True + + def _is_equal_length(self): + """ + Return False if all the arguments have the same length. + """ + if len(self) == 0: + return True + + arg0_len = len(self.argv[0]) + for arg in self.argv[1:]: + if len(arg) != arg0_len: + return False + + return True + + +class values(a_args): + def __str__(self): + return "VALUES%s" % super().__str__() + + +def parse_values(stmt): + return expect(stmt, VALUES) + + +pyfmt_str = terminal("%s") + + +def expect(word, token): + word = word.strip() + if token == VALUES: + if not word.startswith("VALUES"): + raise ProgrammingError( + "VALUES: `%s` does not start with VALUES" % word + ) + word = word[len("VALUES") :].lstrip() + + all_args = [] + while word: + word = word.strip() + + word, arg = expect(word, ARGS) + all_args.append(arg) + word = word.strip() + + if word and not word.startswith(","): + raise ProgrammingError( + "VALUES: expected `,` got %s in %s" % (word[0], word) + ) + word = word[1:] + return "", values(all_args) + + elif token == FUNC: + begins_with_letter = word and (word[0].isalpha() or word[0] == "_") + if not begins_with_letter: + raise ProgrammingError( + "FUNC: `%s` does not begin with `a-zA-z` nor a `_`" % word + ) + + rest = word[1:] + end = 0 + for ch in rest: + if ch.isalnum() or ch == "_": + end += 1 + else: + break + + func_name, rest = word[: end + 1], word[end + 1 :].strip() + + word, args = expect(rest, ARGS) + return word, func(func_name, args) + + elif token == ARGS: + # The form should be: + # (%s) + # (%s, %s...) + # (FUNC, %s...) + # (%s, %s...) + if not (word and word.startswith("(")): + raise ProgrammingError( + "ARGS: supposed to begin with `(` in `%s`" % word + ) + + word = word[1:] + + terms = [] + while True: + word = word.strip() + if not word or word.startswith(")"): + break + + if word == "%s": + terms.append(pyfmt_str) + word = "" + elif not word.startswith("%s"): + word, parsed = expect(word, FUNC) + terms.append(parsed) + else: + terms.append(pyfmt_str) + word = word[2:].strip() + + if word.startswith(","): + word = word[1:] + + if not (word and word.startswith(")")): + raise ProgrammingError( + "ARGS: supposed to end with `)` in `%s`" % word + ) + + word = word[1:] + return word, a_args(terms) + + elif token == EXPR: + if word == "%s": + # Terminal symbol. + return "", pyfmt_str + + # Otherwise we expect a function. + return expect(word, FUNC) + + raise ProgrammingError("Unknown token `%s`" % token) + + +def as_values(values_stmt): + _, _values = parse_values(values_stmt) + return _values diff --git a/google/cloud/spanner_dbapi/types.py b/google/cloud/spanner_dbapi/types.py new file mode 100644 index 0000000000..8c6bd27577 --- /dev/null +++ b/google/cloud/spanner_dbapi/types.py @@ -0,0 +1,98 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +"""Implementation of the type objects and constructors according to the + PEP-0249 specification. + + See + https://www.python.org/dev/peps/pep-0249/#type-objects-and-constructors +""" + +import datetime +import time +from base64 import b64encode + + +def _date_from_ticks(ticks): + """Based on PEP-249 Implementation Hints for Module Authors: + + https://www.python.org/dev/peps/pep-0249/#implementation-hints-for-module-authors + """ + return Date(*time.localtime(ticks)[:3]) + + +def _time_from_ticks(ticks): + """Based on PEP-249 Implementation Hints for Module Authors: + + https://www.python.org/dev/peps/pep-0249/#implementation-hints-for-module-authors + """ + return Time(*time.localtime(ticks)[3:6]) + + +def _timestamp_from_ticks(ticks): + """Based on PEP-249 Implementation Hints for Module Authors: + + https://www.python.org/dev/peps/pep-0249/#implementation-hints-for-module-authors + """ + return Timestamp(*time.localtime(ticks)[:6]) + + +class _DBAPITypeObject(object): + """Implementation of a helper class used for type comparison among similar + but possibly different types. + + See + https://www.python.org/dev/peps/pep-0249/#implementation-hints-for-module-authors + """ + + def __init__(self, *values): + self.values = values + + def __eq__(self, other): + return other in self.values + + +Date = datetime.date +Time = datetime.time +Timestamp = datetime.datetime +DateFromTicks = _date_from_ticks +TimeFromTicks = _time_from_ticks +TimestampFromTicks = _timestamp_from_ticks +Binary = b64encode + +STRING = "STRING" +BINARY = _DBAPITypeObject("TYPE_CODE_UNSPECIFIED", "BYTES", "ARRAY", "STRUCT") +NUMBER = _DBAPITypeObject("BOOL", "INT64", "FLOAT64", "NUMERIC") +DATETIME = _DBAPITypeObject("TIMESTAMP", "DATE") +ROWID = "STRING" + + +class TimestampStr(str): + """[inherited from the alpha release] + + TODO: Decide whether this class is necessary + + TimestampStr exists so that we can purposefully format types as timestamps + compatible with Cloud Spanner's TIMESTAMP type, but right before making + queries, it'll help differentiate between normal strings and the case of + types that should be TIMESTAMP. + """ + + pass + + +class DateStr(str): + """[inherited from the alpha release] + + TODO: Decide whether this class is necessary + + DateStr is a sentinel type to help format Django dates as + compatible with Cloud Spanner's DATE type, but right before making + queries, it'll help differentiate between normal strings and the case of + types that should be DATE. + """ + + pass diff --git a/google/cloud/spanner_dbapi/utils.py b/google/cloud/spanner_dbapi/utils.py new file mode 100644 index 0000000000..f4769e80a4 --- /dev/null +++ b/google/cloud/spanner_dbapi/utils.py @@ -0,0 +1,81 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +import re + + +class PeekIterator: + """ + PeekIterator peeks at the first element out of an iterator + for the sake of operations like auto-population of fields on reading + the first element. + If next's result is an instance of list, it'll be converted into a tuple + to conform with DBAPI v2's sequence expectations. + """ + + def __init__(self, source): + itr_src = iter(source) + + self.__iters = [] + self.__index = 0 + + try: + head = next(itr_src) + # Restitch and prepare to read from multiple iterators. + self.__iters = [iter(itr) for itr in [[head], itr_src]] + except StopIteration: + pass + + def __next__(self): + if self.__index >= len(self.__iters): + raise StopIteration + + iterator = self.__iters[self.__index] + try: + head = next(iterator) + except StopIteration: + # That iterator has been exhausted, try with the next one. + self.__index += 1 + return self.__next__() + else: + return tuple(head) if isinstance(head, list) else head + + def __iter__(self): + return self + + +re_UNICODE_POINTS = re.compile(r"([^\s]*[\u0080-\uFFFF]+[^\s]*)") + + +def backtick_unicode(sql): + matches = list(re_UNICODE_POINTS.finditer(sql)) + if not matches: + return sql + + segments = [] + + last_end = 0 + for match in matches: + start, end = match.span() + if sql[start] != "`" and sql[end - 1] != "`": + segments.append(sql[last_end:start] + "`" + sql[start:end] + "`") + else: + segments.append(sql[last_end:end]) + + last_end = end + + return "".join(segments) + + +def sanitize_literals_for_upload(s): + """ + Convert literals in s, to be fit for consumption by Cloud Spanner. + 1. Convert %% (escaped percent literals) to %. Percent signs must be escaped when + values like %s are used as SQL parameter placeholders but Spanner's query language + uses placeholders like @a0 and doesn't expect percent signs to be escaped. + 2. Quote words containing non-ASCII, with backticks, for example föö to `föö`. + """ + return backtick_unicode(s.replace("%%", "%")) diff --git a/google/cloud/spanner_dbapi/version.py b/google/cloud/spanner_dbapi/version.py new file mode 100644 index 0000000000..88d8f7cdaf --- /dev/null +++ b/google/cloud/spanner_dbapi/version.py @@ -0,0 +1,11 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +import platform + +PY_VERSION = platform.python_version() +VERSION = "2.2.0a1" +DEFAULT_USER_AGENT = "django_spanner/" + VERSION diff --git a/noxfile.py b/noxfile.py index cdd18ff886..bebc3aab48 100644 --- a/noxfile.py +++ b/noxfile.py @@ -65,7 +65,7 @@ def lint_setup_py(session): def default(session): # Install all test dependencies, then install this package in-place. - session.install("mock", "pytest", "pytest-cov") + session.install("mock", "pytest", "pytest-cov", "sqlparse") if session.python != "2.7": session.install("-e", ".[tracing]") diff --git a/tests/unit/spanner_dbapi/__init__.py b/tests/unit/spanner_dbapi/__init__.py new file mode 100644 index 0000000000..6b607710ed --- /dev/null +++ b/tests/unit/spanner_dbapi/__init__.py @@ -0,0 +1,5 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd diff --git a/tests/unit/spanner_dbapi/test__helpers.py b/tests/unit/spanner_dbapi/test__helpers.py new file mode 100644 index 0000000000..e5316d254e --- /dev/null +++ b/tests/unit/spanner_dbapi/test__helpers.py @@ -0,0 +1,130 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +"""Cloud Spanner DB-API Connection class unit tests.""" + +import unittest + +from unittest import mock + + +class TestHelpers(unittest.TestCase): + def test__execute_insert_heterogenous(self): + from google.cloud.spanner_dbapi import _helpers + + sql = "sql" + params = (sql, None) + with mock.patch( + "google.cloud.spanner_dbapi._helpers.sql_pyformat_args_to_spanner", + return_value=params, + ) as mock_pyformat: + with mock.patch( + "google.cloud.spanner_dbapi._helpers.get_param_types", + return_value=None, + ) as mock_param_types: + transaction = mock.MagicMock() + transaction.execute_sql = mock_execute = mock.MagicMock() + _helpers._execute_insert_heterogenous(transaction, [params]) + + mock_pyformat.assert_called_once_with(params[0], params[1]) + mock_param_types.assert_called_once_with(None) + mock_execute.assert_called_once_with( + sql, params=None, param_types=None + ) + + def test__execute_insert_homogenous(self): + from google.cloud.spanner_dbapi import _helpers + + transaction = mock.MagicMock() + transaction.insert = mock.MagicMock() + parts = mock.MagicMock() + parts.get = mock.MagicMock(return_value=0) + + _helpers._execute_insert_homogenous(transaction, parts) + transaction.insert.assert_called_once_with(0, 0, 0) + + def test_handle_insert(self): + from google.cloud.spanner_dbapi import _helpers + + connection = mock.MagicMock() + connection.database.run_in_transaction = mock_run_in = mock.MagicMock() + sql = "sql" + parts = mock.MagicMock() + with mock.patch( + "google.cloud.spanner_dbapi._helpers.parse_insert", + return_value=parts, + ): + parts.get = mock.MagicMock(return_value=True) + mock_run_in.return_value = 0 + result = _helpers.handle_insert(connection, sql, None) + self.assertEqual(result, 0) + + parts.get = mock.MagicMock(return_value=False) + mock_run_in.return_value = 1 + result = _helpers.handle_insert(connection, sql, None) + self.assertEqual(result, 1) + + +class TestColumnInfo(unittest.TestCase): + def test_ctor(self): + from google.cloud.spanner_dbapi.cursor import ColumnInfo + + name = "col-name" + type_code = 8 + display_size = 5 + internal_size = 10 + precision = 3 + scale = None + null_ok = False + + cols = ColumnInfo( + name, + type_code, + display_size, + internal_size, + precision, + scale, + null_ok, + ) + + self.assertEqual(cols.name, name) + self.assertEqual(cols.type_code, type_code) + self.assertEqual(cols.display_size, display_size) + self.assertEqual(cols.internal_size, internal_size) + self.assertEqual(cols.precision, precision) + self.assertEqual(cols.scale, scale) + self.assertEqual(cols.null_ok, null_ok) + self.assertEqual( + cols.fields, + ( + name, + type_code, + display_size, + internal_size, + precision, + scale, + null_ok, + ), + ) + + def test___get_item__(self): + from google.cloud.spanner_dbapi.cursor import ColumnInfo + + fields = ("col-name", 8, 5, 10, 3, None, False) + cols = ColumnInfo(*fields) + + for i in range(0, 7): + self.assertEqual(cols[i], fields[i]) + + def test___str__(self): + from google.cloud.spanner_dbapi.cursor import ColumnInfo + + cols = ColumnInfo("col-name", 8, None, 10, 3, None, False) + + self.assertEqual( + str(cols), + "ColumnInfo(name='col-name', type_code=8, internal_size=10, precision='3')", + ) diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py new file mode 100644 index 0000000000..d545472c57 --- /dev/null +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -0,0 +1,318 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +"""Cloud Spanner DB-API Connection class unit tests.""" + +import unittest +import warnings + +from unittest import mock + + +def _make_credentials(): + from google.auth import credentials + + class _CredentialsWithScopes(credentials.Credentials, credentials.Scoped): + pass + + return mock.Mock(spec=_CredentialsWithScopes) + + +class TestConnection(unittest.TestCase): + + PROJECT = "test-project" + INSTANCE = "test-instance" + DATABASE = "test-database" + USER_AGENT = "user-agent" + CREDENTIALS = _make_credentials() + + def _get_client_info(self): + from google.api_core.gapic_v1.client_info import ClientInfo + + return ClientInfo(user_agent=self.USER_AGENT) + + def _make_connection(self): + from google.cloud.spanner_dbapi import Connection + from google.cloud.spanner_v1.instance import Instance + + # We don't need a real Client object to test the constructor + instance = Instance(self.INSTANCE, client=None) + database = instance.database(self.DATABASE) + return Connection(instance, database) + + def test_property_autocommit_setter(self): + from google.cloud.spanner_dbapi import Connection + + connection = Connection(self.INSTANCE, self.DATABASE) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.commit" + ) as mock_commit: + connection.autocommit = True + mock_commit.assert_called_once_with() + self.assertEqual(connection._autocommit, True) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.commit" + ) as mock_commit: + connection.autocommit = False + mock_commit.assert_not_called() + self.assertEqual(connection._autocommit, False) + + def test_property_database(self): + from google.cloud.spanner_v1.database import Database + + connection = self._make_connection() + self.assertIsInstance(connection.database, Database) + self.assertEqual(connection.database, connection._database) + + def test_property_instance(self): + from google.cloud.spanner_v1.instance import Instance + + connection = self._make_connection() + self.assertIsInstance(connection.instance, Instance) + self.assertEqual(connection.instance, connection._instance) + + def test__session_checkout(self): + from google.cloud.spanner_dbapi import Connection + + with mock.patch( + "google.cloud.spanner_v1.database.Database", + ) as mock_database: + mock_database._pool = mock.MagicMock() + mock_database._pool.get = mock.MagicMock( + return_value="db_session_pool" + ) + connection = Connection(self.INSTANCE, mock_database) + + connection._session_checkout() + mock_database._pool.get.assert_called_once_with() + self.assertEqual(connection._session, "db_session_pool") + + connection._session = "db_session" + connection._session_checkout() + self.assertEqual(connection._session, "db_session") + + def test__release_session(self): + from google.cloud.spanner_dbapi import Connection + + with mock.patch( + "google.cloud.spanner_v1.database.Database", + ) as mock_database: + mock_database._pool = mock.MagicMock() + mock_database._pool.put = mock.MagicMock() + connection = Connection(self.INSTANCE, mock_database) + connection._session = "session" + + connection._release_session() + mock_database._pool.put.assert_called_once_with("session") + self.assertIsNone(connection._session) + + def test_transaction_checkout(self): + from google.cloud.spanner_dbapi import Connection + + connection = Connection(self.INSTANCE, self.DATABASE) + connection._session_checkout = mock_checkout = mock.MagicMock( + autospec=True + ) + connection.transaction_checkout() + mock_checkout.assert_called_once_with() + + connection._transaction = mock_transaction = mock.MagicMock() + mock_transaction.committed = mock_transaction.rolled_back = False + self.assertEqual(connection.transaction_checkout(), mock_transaction) + + connection._autocommit = True + self.assertIsNone(connection.transaction_checkout()) + + def test_close(self): + from google.cloud.spanner_dbapi import connect, InterfaceError + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + self.assertFalse(connection.is_closed) + connection.close() + self.assertTrue(connection.is_closed) + + with self.assertRaises(InterfaceError): + connection.cursor() + + connection._transaction = mock_transaction = mock.MagicMock() + mock_transaction.committed = mock_transaction.rolled_back = False + mock_transaction.rollback = mock_rollback = mock.MagicMock() + connection.close() + mock_rollback.assert_called_once_with() + + @mock.patch.object(warnings, "warn") + def test_commit(self, mock_warn): + from google.cloud.spanner_dbapi import Connection + from google.cloud.spanner_dbapi.connection import ( + AUTOCOMMIT_MODE_WARNING, + ) + + connection = Connection(self.INSTANCE, self.DATABASE) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection._release_session" + ) as mock_release: + connection.commit() + mock_release.assert_not_called() + + connection._transaction = mock_transaction = mock.MagicMock() + mock_transaction.commit = mock_commit = mock.MagicMock() + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection._release_session" + ) as mock_release: + connection.commit() + mock_commit.assert_called_once_with() + mock_release.assert_called_once_with() + + connection._autocommit = True + connection.commit() + mock_warn.assert_called_once_with( + AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2 + ) + + @mock.patch.object(warnings, "warn") + def test_rollback(self, mock_warn): + from google.cloud.spanner_dbapi import Connection + from google.cloud.spanner_dbapi.connection import ( + AUTOCOMMIT_MODE_WARNING, + ) + + connection = Connection(self.INSTANCE, self.DATABASE) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection._release_session" + ) as mock_release: + connection.rollback() + mock_release.assert_not_called() + + connection._transaction = mock_transaction = mock.MagicMock() + mock_transaction.rollback = mock_rollback = mock.MagicMock() + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection._release_session" + ) as mock_release: + connection.rollback() + mock_rollback.assert_called_once_with() + mock_release.assert_called_once_with() + + connection._autocommit = True + connection.rollback() + mock_warn.assert_called_once_with( + AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2 + ) + + def test_run_prior_DDL_statements(self): + from google.cloud.spanner_dbapi import Connection, InterfaceError + + with mock.patch( + "google.cloud.spanner_v1.database.Database", autospec=True, + ) as mock_database: + connection = Connection(self.INSTANCE, mock_database) + + connection.run_prior_DDL_statements() + mock_database.update_ddl.assert_not_called() + + ddl = ["ddl"] + connection._ddl_statements = ddl + + connection.run_prior_DDL_statements() + mock_database.update_ddl.assert_called_once_with(ddl) + + connection.is_closed = True + + with self.assertRaises(InterfaceError): + connection.run_prior_DDL_statements() + + def test_context(self): + from google.cloud.spanner_dbapi import Connection + + connection = Connection(self.INSTANCE, self.DATABASE) + with connection as conn: + self.assertEqual(conn, connection) + + self.assertTrue(connection.is_closed) + + def test_connect(self): + from google.cloud.spanner_dbapi import Connection, connect + + with mock.patch("google.cloud.spanner_v1.Client"): + with mock.patch( + "google.api_core.gapic_v1.client_info.ClientInfo", + return_value=self._get_client_info(), + ): + connection = connect( + self.INSTANCE, + self.DATABASE, + self.PROJECT, + self.CREDENTIALS, + self.USER_AGENT, + ) + self.assertIsInstance(connection, Connection) + + def test_connect_instance_not_found(self): + from google.cloud.spanner_dbapi import connect + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=False, + ): + with self.assertRaises(ValueError): + connect("test-instance", "test-database") + + def test_connect_database_not_found(self): + from google.cloud.spanner_dbapi import connect + + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=False, + ): + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with self.assertRaises(ValueError): + connect("test-instance", "test-database") + + def test_default_sessions_pool(self): + from google.cloud.spanner_dbapi import connect + + with mock.patch("google.cloud.spanner_v1.instance.Instance.database"): + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + self.assertIsNotNone(connection.database._pool) + + def test_sessions_pool(self): + from google.cloud.spanner_dbapi import connect + from google.cloud.spanner_v1.pool import FixedSizePool + + database_id = "test-database" + pool = FixedSizePool() + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.database" + ) as database_mock: + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + connect("test-instance", database_id, pool=pool) + database_mock.assert_called_once_with(database_id, pool=pool) diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py new file mode 100644 index 0000000000..09288df94e --- /dev/null +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -0,0 +1,460 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +"""Cursor() class unit tests.""" + +import unittest + +from unittest import mock + + +class TestCursor(unittest.TestCase): + + INSTANCE = "test-instance" + DATABASE = "test-database" + + def _get_target_class(self): + from google.cloud.spanner_dbapi import Cursor + + return Cursor + + def _make_one(self, *args, **kwargs): + return self._get_target_class()(*args, **kwargs) + + def _make_connection(self, *args, **kwargs): + from google.cloud.spanner_dbapi import Connection + + return Connection(*args, **kwargs) + + def test_property_connection(self): + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + self.assertEqual(cursor.connection, connection) + + def test_property_description(self): + from google.cloud.spanner_dbapi._helpers import ColumnInfo + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + + self.assertIsNone(cursor.description) + cursor._result_set = res_set = mock.MagicMock() + res_set.metadata.row_type.fields = [mock.MagicMock()] + self.assertIsNotNone(cursor.description) + self.assertIsInstance(cursor.description[0], ColumnInfo) + + def test_property_rowcount(self): + from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + self.assertEqual(cursor.rowcount, _UNSET_COUNT) + + def test_callproc(self): + from google.cloud.spanner_dbapi.exceptions import InterfaceError + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + cursor._is_closed = True + with self.assertRaises(InterfaceError): + cursor.callproc(procname=None) + + def test_close(self): + from google.cloud.spanner_dbapi import connect, InterfaceError + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect(self.INSTANCE, self.DATABASE) + + cursor = connection.cursor() + self.assertFalse(cursor.is_closed) + + cursor.close() + + self.assertTrue(cursor.is_closed) + with self.assertRaises(InterfaceError): + cursor.execute("SELECT * FROM database") + + def test_do_execute_update(self): + from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + transaction = mock.MagicMock() + + def run_helper(ret_value): + transaction.execute_update.return_value = ret_value + res = cursor._do_execute_update( + transaction=transaction, sql="sql", params=None, + ) + return res + + expected = "good" + self.assertEqual(run_helper(expected), expected) + self.assertEqual(cursor._row_count, _UNSET_COUNT) + + expected = 1234 + self.assertEqual(run_helper(expected), expected) + self.assertEqual(cursor._row_count, expected) + + def test_execute_programming_error(self): + from google.cloud.spanner_dbapi.exceptions import ProgrammingError + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + cursor.connection = None + with self.assertRaises(ProgrammingError): + cursor.execute(sql="") + + def test_execute_attribute_error(self): + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + + with self.assertRaises(AttributeError): + cursor.execute(sql="") + + def test_execute_autocommit_off(self): + from google.cloud.spanner_dbapi.utils import PeekIterator + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + cursor.connection._autocommit = False + cursor.connection.transaction_checkout = mock.MagicMock(autospec=True) + + cursor.execute("sql") + self.assertIsInstance(cursor._result_set, mock.MagicMock) + self.assertIsInstance(cursor._itr, PeekIterator) + + def test_execute_statement(self): + from google.cloud.spanner_dbapi import parse_utils + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + return_value=parse_utils.STMT_DDL, + ) as mock_classify_stmt: + sql = "sql" + cursor.execute(sql=sql) + mock_classify_stmt.assert_called_once_with(sql) + self.assertEqual(cursor.connection._ddl_statements, [sql]) + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + return_value=parse_utils.STMT_NON_UPDATING, + ): + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor._handle_DQL", + return_value=parse_utils.STMT_NON_UPDATING, + ) as mock_handle_ddl: + connection.autocommit = True + sql = "sql" + cursor.execute(sql=sql) + mock_handle_ddl.assert_called_once_with(sql, None) + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + return_value=parse_utils.STMT_INSERT, + ): + with mock.patch( + "google.cloud.spanner_dbapi._helpers.handle_insert", + return_value=parse_utils.STMT_INSERT, + ) as mock_handle_insert: + sql = "sql" + cursor.execute(sql=sql) + mock_handle_insert.assert_called_once_with( + connection, sql, None + ) + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + return_value="other_statement", + ): + cursor.connection._database = mock_db = mock.MagicMock() + mock_db.run_in_transaction = mock_run_in = mock.MagicMock() + sql = "sql" + cursor.execute(sql=sql) + mock_run_in.assert_called_once_with( + cursor._do_execute_update, sql, None + ) + + def test_execute_integrity_error(self): + from google.api_core import exceptions + from google.cloud.spanner_dbapi.exceptions import IntegrityError + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + side_effect=exceptions.AlreadyExists("message"), + ): + with self.assertRaises(IntegrityError): + cursor.execute(sql="sql") + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + side_effect=exceptions.FailedPrecondition("message"), + ): + with self.assertRaises(IntegrityError): + cursor.execute(sql="sql") + + def test_execute_invalid_argument(self): + from google.api_core import exceptions + from google.cloud.spanner_dbapi.exceptions import ProgrammingError + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + side_effect=exceptions.InvalidArgument("message"), + ): + with self.assertRaises(ProgrammingError): + cursor.execute(sql="sql") + + def test_execute_internal_server_error(self): + from google.api_core import exceptions + from google.cloud.spanner_dbapi.exceptions import OperationalError + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + side_effect=exceptions.InternalServerError("message"), + ): + with self.assertRaises(OperationalError): + cursor.execute(sql="sql") + + def test_executemany_on_closed_cursor(self): + from google.cloud.spanner_dbapi import InterfaceError + from google.cloud.spanner_dbapi import connect + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor.close() + + with self.assertRaises(InterfaceError): + cursor.executemany( + """SELECT * FROM table1 WHERE "col1" = @a1""", () + ) + + def test_executemany(self): + from google.cloud.spanner_dbapi import connect + + operation = """SELECT * FROM table1 WHERE "col1" = @a1""" + params_seq = ((1,), (2,)) + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.execute" + ) as execute_mock: + cursor.executemany(operation, params_seq) + + execute_mock.assert_has_calls( + (mock.call(operation, (1,)), mock.call(operation, (2,))) + ) + + def test_fetchone(self): + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + lst = [1, 2, 3] + cursor._itr = iter(lst) + for i in range(len(lst)): + self.assertEqual(cursor.fetchone(), lst[i]) + self.assertIsNone(cursor.fetchone()) + + def test_fetchmany(self): + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + lst = [(1,), (2,), (3,)] + cursor._itr = iter(lst) + + self.assertEqual(cursor.fetchmany(), [lst[0]]) + + result = cursor.fetchmany(len(lst)) + self.assertEqual(result, lst[1:]) + + def test_fetchall(self): + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + lst = [(1,), (2,), (3,)] + cursor._itr = iter(lst) + self.assertEqual(cursor.fetchall(), lst) + + def test_nextset(self): + from google.cloud.spanner_dbapi import exceptions + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + cursor.close() + with self.assertRaises(exceptions.InterfaceError): + cursor.nextset() + + def test_setinputsizes(self): + from google.cloud.spanner_dbapi import exceptions + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + cursor.close() + with self.assertRaises(exceptions.InterfaceError): + cursor.setinputsizes(sizes=None) + + def test_setoutputsize(self): + from google.cloud.spanner_dbapi import exceptions + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + cursor.close() + with self.assertRaises(exceptions.InterfaceError): + cursor.setoutputsize(size=None) + + # def test_handle_insert(self): + # pass + # + # def test_do_execute_insert_heterogenous(self): + # pass + # + # def test_do_execute_insert_homogenous(self): + # pass + + def test_handle_dql(self): + from google.cloud.spanner_dbapi import utils + from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + connection.database.snapshot.return_value.__enter__.return_value = ( + mock_snapshot + ) = mock.MagicMock() + cursor = self._make_one(connection) + + mock_snapshot.execute_sql.return_value = int(0) + cursor._handle_DQL("sql", params=None) + self.assertEqual(cursor._row_count, 0) + self.assertIsNone(cursor._itr) + + mock_snapshot.execute_sql.return_value = "0" + cursor._handle_DQL("sql", params=None) + self.assertEqual(cursor._result_set, "0") + self.assertIsInstance(cursor._itr, utils.PeekIterator) + self.assertEqual(cursor._row_count, _UNSET_COUNT) + + def test_context(self): + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + with cursor as c: + self.assertEqual(c, cursor) + + self.assertTrue(c.is_closed) + + def test_next(self): + from google.cloud.spanner_dbapi import exceptions + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + with self.assertRaises(exceptions.ProgrammingError): + cursor.__next__() + + lst = [(1,), (2,), (3,)] + cursor._itr = iter(lst) + i = 0 + for c in cursor._itr: + self.assertEqual(c, lst[i]) + i += 1 + + def test_iter(self): + from google.cloud.spanner_dbapi import exceptions + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + with self.assertRaises(exceptions.ProgrammingError): + _ = iter(cursor) + + iterator = iter([(1,), (2,), (3,)]) + cursor._itr = iterator + self.assertEqual(iter(cursor), iterator) + + def test_list_tables(self): + from google.cloud.spanner_dbapi import _helpers + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + + table_list = ["table1", "table2", "table3"] + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.run_sql_in_snapshot", + return_value=table_list, + ) as mock_run_sql: + cursor.list_tables() + mock_run_sql.assert_called_once_with(_helpers.SQL_LIST_TABLES) + + def test_run_sql_in_snapshot(self): + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + connection.database.snapshot.return_value.__enter__.return_value = ( + mock_snapshot + ) = mock.MagicMock() + cursor = self._make_one(connection) + + results = 1, 2, 3 + mock_snapshot.execute_sql.return_value = results + self.assertEqual(cursor.run_sql_in_snapshot("sql"), list(results)) + + def test_get_table_column_schema(self): + from google.cloud.spanner_dbapi.cursor import ColumnDetails + from google.cloud.spanner_dbapi import _helpers + from google.cloud.spanner_v1 import param_types + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + + column_name = "column1" + is_nullable = "YES" + spanner_type = "spanner_type" + rows = [(column_name, is_nullable, spanner_type)] + expected = { + column_name: ColumnDetails( + null_ok=True, spanner_type=spanner_type, + ) + } + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.run_sql_in_snapshot", + return_value=rows, + ) as mock_run_sql: + table_name = "table1" + result = cursor.get_table_column_schema(table_name=table_name) + mock_run_sql.assert_called_once_with( + sql=_helpers.SQL_GET_TABLE_COLUMN_SCHEMA, + params={"table_name": table_name}, + param_types={"table_name": param_types.STRING}, + ) + self.assertEqual(result, expected) diff --git a/tests/unit/spanner_dbapi/test_globals.py b/tests/unit/spanner_dbapi/test_globals.py new file mode 100644 index 0000000000..3f8360e2ea --- /dev/null +++ b/tests/unit/spanner_dbapi/test_globals.py @@ -0,0 +1,20 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +import unittest + + +class TestDBAPIGlobals(unittest.TestCase): + def test_apilevel(self): + from google.cloud.spanner_dbapi import apilevel + from google.cloud.spanner_dbapi import paramstyle + from google.cloud.spanner_dbapi import threadsafety + + self.assertEqual(apilevel, "2.0", "We implement PEP-0249 version 2.0") + self.assertEqual(paramstyle, "format", "Cloud Spanner uses @param") + self.assertEqual( + threadsafety, 1, "Threads may share module but not connections" + ) diff --git a/tests/unit/spanner_dbapi/test_parse_utils.py b/tests/unit/spanner_dbapi/test_parse_utils.py new file mode 100644 index 0000000000..1bd38c85eb --- /dev/null +++ b/tests/unit/spanner_dbapi/test_parse_utils.py @@ -0,0 +1,454 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +import unittest + +from google.cloud.spanner_v1 import param_types + + +class TestParseUtils(unittest.TestCase): + def test_classify_stmt(self): + from google.cloud.spanner_dbapi.parse_utils import STMT_DDL + from google.cloud.spanner_dbapi.parse_utils import STMT_INSERT + from google.cloud.spanner_dbapi.parse_utils import STMT_NON_UPDATING + from google.cloud.spanner_dbapi.parse_utils import STMT_UPDATING + from google.cloud.spanner_dbapi.parse_utils import classify_stmt + + cases = ( + ("SELECT 1", STMT_NON_UPDATING), + ("SELECT s.SongName FROM Songs AS s", STMT_NON_UPDATING), + ( + "WITH sq AS (SELECT SchoolID FROM Roster) SELECT * from sq", + STMT_NON_UPDATING, + ), + ( + "CREATE TABLE django_content_type (id STRING(64) NOT NULL, name STRING(100) " + "NOT NULL, app_label STRING(100) NOT NULL, model STRING(100) NOT NULL) PRIMARY KEY(id)", + STMT_DDL, + ), + ( + "CREATE INDEX SongsBySingerAlbumSongNameDesc ON " + "Songs(SingerId, AlbumId, SongName DESC), INTERLEAVE IN Albums", + STMT_DDL, + ), + ("CREATE INDEX SongsBySongName ON Songs(SongName)", STMT_DDL), + ( + "CREATE INDEX AlbumsByAlbumTitle2 ON Albums(AlbumTitle) STORING (MarketingBudget)", + STMT_DDL, + ), + ("INSERT INTO table (col1) VALUES (1)", STMT_INSERT), + ("UPDATE table SET col1 = 1 WHERE col1 = NULL", STMT_UPDATING), + ) + + for query, want_class in cases: + self.assertEqual(classify_stmt(query), want_class) + + def test_parse_insert(self): + from google.cloud.spanner_dbapi.parse_utils import parse_insert + from google.cloud.spanner_dbapi.exceptions import ProgrammingError + + with self.assertRaises(ProgrammingError): + parse_insert("bad-sql", None) + + cases = [ + ( + "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)", + [1, 2, 3, 4, 5, 6], + { + "sql_params_list": [ + ( + "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)", + (1, 2, 3), + ), + ( + "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)", + (4, 5, 6), + ), + ] + }, + ), + ( + "INSERT INTO django_migrations(app, name, applied) VALUES (%s, %s, %s)", + [1, 2, 3, 4, 5, 6], + { + "sql_params_list": [ + ( + "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)", + (1, 2, 3), + ), + ( + "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)", + (4, 5, 6), + ), + ] + }, + ), + ( + "INSERT INTO sales.addresses (street, city, state, zip_code) " + "SELECT street, city, state, zip_code FROM sales.customers" + "ORDER BY first_name, last_name", + None, + { + "sql_params_list": [ + ( + "INSERT INTO sales.addresses (street, city, state, zip_code) " + "SELECT street, city, state, zip_code FROM sales.customers" + "ORDER BY first_name, last_name", + None, + ) + ] + }, + ), + ( + "INSERT INTO ap (n, ct, cn) " + "VALUES (%s, %s, %s), (%s, %s, %s), (%s, %s, %s),(%s, %s, %s)", + (1, 2, 3, 4, 5, 6, 7, 8, 9), + { + "sql_params_list": [ + ( + "INSERT INTO ap (n, ct, cn) VALUES (%s, %s, %s)", + (1, 2, 3), + ), + ( + "INSERT INTO ap (n, ct, cn) VALUES (%s, %s, %s)", + (4, 5, 6), + ), + ( + "INSERT INTO ap (n, ct, cn) VALUES (%s, %s, %s)", + (7, 8, 9), + ), + ] + }, + ), + ( + "INSERT INTO `no` (`yes`) VALUES (%s)", + (1, 4, 5), + { + "sql_params_list": [ + ("INSERT INTO `no` (`yes`) VALUES (%s)", (1,)), + ("INSERT INTO `no` (`yes`) VALUES (%s)", (4,)), + ("INSERT INTO `no` (`yes`) VALUES (%s)", (5,)), + ] + }, + ), + ( + "INSERT INTO T (f1, f2) VALUES (1, 2)", + None, + { + "sql_params_list": [ + ("INSERT INTO T (f1, f2) VALUES (1, 2)", None) + ] + }, + ), + ( + "INSERT INTO `no` (`yes`, tiff) VALUES (%s, LOWER(%s)), (%s, %s), (%s, %s)", + (1, "FOO", 5, 10, 11, 29), + { + "sql_params_list": [ + ( + "INSERT INTO `no` (`yes`, tiff) VALUES(%s, LOWER(%s))", + (1, "FOO"), + ), + ( + "INSERT INTO `no` (`yes`, tiff) VALUES(%s, %s)", + (5, 10), + ), + ( + "INSERT INTO `no` (`yes`, tiff) VALUES(%s, %s)", + (11, 29), + ), + ] + }, + ), + ] + + sql = "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)" + with self.assertRaises(ProgrammingError): + parse_insert(sql, None) + + for sql, params, want in cases: + with self.subTest(sql=sql): + got = parse_insert(sql, params) + self.assertEqual( + got, want, "Mismatch with parse_insert of `%s`" % sql + ) + + def test_parse_insert_invalid(self): + from google.cloud.spanner_dbapi import exceptions + from google.cloud.spanner_dbapi.parse_utils import parse_insert + + cases = [ + ( + "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s), (%s, %s, %s)", + [1, 2, 3, 4, 5, 6, 7], + "len\\(params\\)=7 MUST be a multiple of len\\(pyformat_args\\)=3", + ), + ( + "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s), (%s, %s, LOWER(%s))", + [1, 2, 3, 4, 5, 6, 7], + "Invalid length: VALUES\\(...\\) len: 6 != len\\(params\\): 7", + ), + ( + "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s), (%s, %s, LOWER(%s)))", + [1, 2, 3, 4, 5, 6], + "VALUES: expected `,` got \\) in \\)", + ), + ] + + for sql, params, wantException in cases: + with self.subTest(sql=sql): + self.assertRaisesRegex( + exceptions.ProgrammingError, + wantException, + lambda: parse_insert(sql, params), + ) + + def test_rows_for_insert_or_update(self): + from google.cloud.spanner_dbapi.parse_utils import ( + rows_for_insert_or_update, + ) + from google.cloud.spanner_dbapi.exceptions import Error + + with self.assertRaises(Error): + rows_for_insert_or_update([0], [[]]) + + with self.assertRaises(Error): + rows_for_insert_or_update([0], None, ["0", "%s"]) + + cases = [ + ( + ["id", "app", "name"], + [(5, "ap", "n"), (6, "bp", "m")], + None, + [(5, "ap", "n"), (6, "bp", "m")], + ), + ( + ["app", "name"], + [("ap", "n"), ("bp", "m")], + None, + [("ap", "n"), ("bp", "m")], + ), + ( + ["app", "name", "fn"], + ["ap", "n", "f1", "bp", "m", "f2", "cp", "o", "f3"], + ["(%s, %s, %s)", "(%s, %s, %s)", "(%s, %s, %s)"], + [("ap", "n", "f1"), ("bp", "m", "f2"), ("cp", "o", "f3")], + ), + ( + ["app", "name", "fn", "ln"], + [ + ("ap", "n", (45, "nested"), "ll"), + ("bp", "m", "f2", "mt"), + ("fp", "cp", "o", "f3"), + ], + None, + [ + ("ap", "n", (45, "nested"), "ll"), + ("bp", "m", "f2", "mt"), + ("fp", "cp", "o", "f3"), + ], + ), + ( + ["app", "name", "fn"], + ["ap", "n", "f1"], + None, + [("ap", "n", "f1")], + ), + ] + + for i, (columns, params, pyformat_args, want) in enumerate(cases): + with self.subTest(i=i): + got = rows_for_insert_or_update(columns, params, pyformat_args) + self.assertEqual(got, want) + + def test_sql_pyformat_args_to_spanner(self): + import decimal + + from google.cloud.spanner_dbapi.parse_utils import ( + sql_pyformat_args_to_spanner, + ) + + cases = [ + ( + ( + "SELECT * from t WHERE f1=%s, f2 = %s, f3=%s", + (10, "abc", "y**$22l3f"), + ), + ( + "SELECT * from t WHERE f1=@a0, f2 = @a1, f3=@a2", + {"a0": 10, "a1": "abc", "a2": "y**$22l3f"}, + ), + ), + ( + ( + "INSERT INTO t (f1, f2, f2) VALUES (%s, %s, %s)", + ("app", "name", "applied"), + ), + ( + "INSERT INTO t (f1, f2, f2) VALUES (@a0, @a1, @a2)", + {"a0": "app", "a1": "name", "a2": "applied"}, + ), + ), + ( + ( + "INSERT INTO t (f1, f2, f2) VALUES (%(f1)s, %(f2)s, %(f3)s)", + {"f1": "app", "f2": "name", "f3": "applied"}, + ), + ( + "INSERT INTO t (f1, f2, f2) VALUES (@a0, @a1, @a2)", + {"a0": "app", "a1": "name", "a2": "applied"}, + ), + ), + ( + # Intentionally using a dict with more keys than will be resolved. + ( + "SELECT * from t WHERE f1=%(f1)s", + {"f1": "app", "f2": "name"}, + ), + ("SELECT * from t WHERE f1=@a0", {"a0": "app"}), + ), + ( + # No args to replace, we MUST return the original params dict + # since it might be useful to pass to the next user. + ("SELECT * from t WHERE id=10", {"f1": "app", "f2": "name"}), + ("SELECT * from t WHERE id=10", {"f1": "app", "f2": "name"}), + ), + ( + ( + "SELECT (an.p + %s) AS np FROM an WHERE (an.p + %s) = %s", + (1, 1.0, decimal.Decimal("31")), + ), + ( + "SELECT (an.p + @a0) AS np FROM an WHERE (an.p + @a1) = @a2", + {"a0": 1, "a1": 1.0, "a2": 31.0}, + ), + ), + ] + for ((sql_in, params), sql_want) in cases: + with self.subTest(sql=sql_in): + got_sql, got_named_args = sql_pyformat_args_to_spanner( + sql_in, params + ) + want_sql, want_named_args = sql_want + self.assertEqual(got_sql, want_sql, "SQL does not match") + self.assertEqual( + got_named_args, want_named_args, "Named args do not match" + ) + + def test_sql_pyformat_args_to_spanner_invalid(self): + from google.cloud.spanner_dbapi import exceptions + from google.cloud.spanner_dbapi.parse_utils import ( + sql_pyformat_args_to_spanner, + ) + + cases = [ + ( + "SELECT * from t WHERE f1=%s, f2 = %s, f3=%s, extra=%s", + (10, "abc", "y**$22l3f"), + ) + ] + for sql, params in cases: + with self.subTest(sql=sql): + self.assertRaisesRegex( + exceptions.Error, + "pyformat_args mismatch", + lambda: sql_pyformat_args_to_spanner(sql, params), + ) + + def test_cast_for_spanner(self): + import decimal + + from google.cloud.spanner_dbapi.parse_utils import cast_for_spanner + + value = decimal.Decimal(3) + self.assertEqual(cast_for_spanner(value), float(3.0)) + self.assertEqual(cast_for_spanner(5), 5) + self.assertEqual(cast_for_spanner("string"), "string") + + def test_get_param_types(self): + import datetime + + from google.cloud.spanner_dbapi.parse_utils import DateStr + from google.cloud.spanner_dbapi.parse_utils import TimestampStr + from google.cloud.spanner_dbapi.parse_utils import get_param_types + + params = { + "a1": 10, + "b1": "string", + "c1": 10.39, + "d1": TimestampStr("2005-08-30T01:01:01.000001Z"), + "e1": DateStr("2019-12-05"), + "f1": True, + "g1": datetime.datetime(2011, 9, 1, 13, 20, 30), + "h1": datetime.date(2011, 9, 1), + "i1": b"bytes", + "j1": None, + } + want_types = { + "a1": param_types.INT64, + "b1": param_types.STRING, + "c1": param_types.FLOAT64, + "d1": param_types.TIMESTAMP, + "e1": param_types.DATE, + "f1": param_types.BOOL, + "g1": param_types.TIMESTAMP, + "h1": param_types.DATE, + "i1": param_types.BYTES, + } + got_types = get_param_types(params) + self.assertEqual(got_types, want_types) + + def test_get_param_types_none(self): + from google.cloud.spanner_dbapi.parse_utils import get_param_types + + self.assertEqual(get_param_types(None), None) + + def test_ensure_where_clause(self): + from google.cloud.spanner_dbapi.parse_utils import ensure_where_clause + + cases = [ + ( + "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1", + "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1", + ), + ( + "UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5", + "UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5 WHERE 1=1", + ), + ( + "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2", + "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2", + ), + ( + "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", + "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", + ), + ( + "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", + "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", + ), + ("DELETE * FROM TABLE", "DELETE * FROM TABLE WHERE 1=1"), + ] + + for sql, want in cases: + with self.subTest(sql=sql): + got = ensure_where_clause(sql) + self.assertEqual(got, want) + + def test_escape_name(self): + from google.cloud.spanner_dbapi.parse_utils import escape_name + + cases = ( + ("SELECT", "`SELECT`"), + ("dashed-value", "`dashed-value`"), + ("with space", "`with space`"), + ("name", "name"), + ("", ""), + ) + for name, want in cases: + with self.subTest(name=name): + got = escape_name(name) + self.assertEqual(got, want) diff --git a/tests/unit/spanner_dbapi/test_parser.py b/tests/unit/spanner_dbapi/test_parser.py new file mode 100644 index 0000000000..d5baf9d824 --- /dev/null +++ b/tests/unit/spanner_dbapi/test_parser.py @@ -0,0 +1,288 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +import unittest + +from unittest import mock + + +class TestParser(unittest.TestCase): + def test_func(self): + from google.cloud.spanner_dbapi.parser import FUNC + from google.cloud.spanner_dbapi.parser import a_args + from google.cloud.spanner_dbapi.parser import expect + from google.cloud.spanner_dbapi.parser import func + from google.cloud.spanner_dbapi.parser import pyfmt_str + + cases = [ + ("_91())", ")", func("_91", a_args([]))), + ("_a()", "", func("_a", a_args([]))), + ("___()", "", func("___", a_args([]))), + ("abc()", "", func("abc", a_args([]))), + ( + "AF112(%s, LOWER(%s, %s), rand(%s, %s, TAN(%s, %s)))", + "", + func( + "AF112", + a_args( + [ + pyfmt_str, + func("LOWER", a_args([pyfmt_str, pyfmt_str])), + func( + "rand", + a_args( + [ + pyfmt_str, + pyfmt_str, + func( + "TAN", + a_args([pyfmt_str, pyfmt_str]), + ), + ] + ), + ), + ] + ), + ), + ), + ] + + for text, want_unconsumed, want_parsed in cases: + with self.subTest(text=text): + got_unconsumed, got_parsed = expect(text, FUNC) + self.assertEqual(got_parsed, want_parsed) + self.assertEqual(got_unconsumed, want_unconsumed) + + def test_func_fail(self): + from google.cloud.spanner_dbapi.exceptions import ProgrammingError + from google.cloud.spanner_dbapi.parser import FUNC + from google.cloud.spanner_dbapi.parser import expect + + cases = [ + ("", "FUNC: `` does not begin with `a-zA-z` nor a `_`"), + ("91", "FUNC: `91` does not begin with `a-zA-z` nor a `_`"), + ("_91", "supposed to begin with `\\(`"), + ("_91(", "supposed to end with `\\)`"), + ("_.()", "supposed to begin with `\\(`"), + ("_a.b()", "supposed to begin with `\\(`"), + ] + + for text, wantException in cases: + with self.subTest(text=text): + self.assertRaisesRegex( + ProgrammingError, wantException, lambda: expect(text, FUNC) + ) + + def test_func_eq(self): + from google.cloud.spanner_dbapi.parser import func + + func1 = func("func1", None) + func2 = func("func2", None) + self.assertFalse(func1 == object) + self.assertFalse(func1 == func2) + func2.name = func1.name + func1.args = 0 + func2.args = "0" + self.assertFalse(func1 == func2) + func1.args = [0] + func2.args = [0, 0] + self.assertFalse(func1 == func2) + func2.args = func1.args + self.assertTrue(func1 == func2) + + def test_a_args(self): + from google.cloud.spanner_dbapi.parser import ARGS + from google.cloud.spanner_dbapi.parser import a_args + from google.cloud.spanner_dbapi.parser import expect + from google.cloud.spanner_dbapi.parser import func + from google.cloud.spanner_dbapi.parser import pyfmt_str + + cases = [ + ("()", "", a_args([])), + ("(%s)", "", a_args([pyfmt_str])), + ("(%s,)", "", a_args([pyfmt_str])), + ("(%s),", ",", a_args([pyfmt_str])), + ( + "(%s,%s, f1(%s, %s))", + "", + a_args( + [ + pyfmt_str, + pyfmt_str, + func("f1", a_args([pyfmt_str, pyfmt_str])), + ] + ), + ), + ] + + for text, want_unconsumed, want_parsed in cases: + with self.subTest(text=text): + got_unconsumed, got_parsed = expect(text, ARGS) + self.assertEqual(got_parsed, want_parsed) + self.assertEqual(got_unconsumed, want_unconsumed) + + def test_a_args_fail(self): + from google.cloud.spanner_dbapi.exceptions import ProgrammingError + from google.cloud.spanner_dbapi.parser import ARGS + from google.cloud.spanner_dbapi.parser import expect + + cases = [ + ("", "ARGS: supposed to begin with `\\(`"), + ("(", "ARGS: supposed to end with `\\)`"), + (")", "ARGS: supposed to begin with `\\(`"), + ("(%s,%s, f1(%s, %s), %s", "ARGS: supposed to end with `\\)`"), + ] + + for text, wantException in cases: + with self.subTest(text=text): + self.assertRaisesRegex( + ProgrammingError, wantException, lambda: expect(text, ARGS) + ) + + def test_a_args_has_expr(self): + from google.cloud.spanner_dbapi.parser import a_args + + self.assertFalse(a_args([]).has_expr()) + self.assertTrue(a_args([[0]]).has_expr()) + + def test_a_args_eq(self): + from google.cloud.spanner_dbapi.parser import a_args + + a1 = a_args([0]) + self.assertFalse(a1 == object()) + a2 = a_args([0, 0]) + self.assertFalse(a1 == a2) + a1.argv = [0, 1] + self.assertFalse(a1 == a2) + a2.argv = [0, 1] + self.assertTrue(a1 == a2) + + def test_a_args_homogeneous(self): + from google.cloud.spanner_dbapi.parser import a_args + from google.cloud.spanner_dbapi.parser import terminal + + a_obj = a_args([a_args([terminal(10 ** i)]) for i in range(10)]) + self.assertTrue(a_obj.homogenous()) + + a_obj = a_args([a_args([[object()]]) for _ in range(10)]) + self.assertFalse(a_obj.homogenous()) + + def test_a_args__is_equal_length(self): + from google.cloud.spanner_dbapi.parser import a_args + + a_obj = a_args([]) + self.assertTrue(a_obj._is_equal_length()) + + def test_values(self): + from google.cloud.spanner_dbapi.parser import a_args + from google.cloud.spanner_dbapi.parser import terminal + from google.cloud.spanner_dbapi.parser import values + + a_obj = a_args([a_args([terminal(10 ** i)]) for i in range(10)]) + self.assertEqual(str(values(a_obj)), "VALUES%s" % str(a_obj)) + + def test_expect(self): + from google.cloud.spanner_dbapi.parser import ARGS + from google.cloud.spanner_dbapi.parser import EXPR + from google.cloud.spanner_dbapi.parser import FUNC + from google.cloud.spanner_dbapi.parser import expect + from google.cloud.spanner_dbapi.parser import pyfmt_str + from google.cloud.spanner_dbapi import exceptions + + with self.assertRaises(exceptions.ProgrammingError): + expect(word="", token=ARGS) + with self.assertRaises(exceptions.ProgrammingError): + expect(word="ABC", token=ARGS) + with self.assertRaises(exceptions.ProgrammingError): + expect(word="(", token=ARGS) + + expected = "", pyfmt_str + self.assertEqual(expect("%s", EXPR), expected) + + expected = expect("function()", FUNC) + self.assertEqual(expect("function()", EXPR), expected) + + with self.assertRaises(exceptions.ProgrammingError): + expect(word="", token="ABC") + + def test_expect_values(self): + from google.cloud.spanner_dbapi.parser import VALUES + from google.cloud.spanner_dbapi.parser import a_args + from google.cloud.spanner_dbapi.parser import expect + from google.cloud.spanner_dbapi.parser import func + from google.cloud.spanner_dbapi.parser import pyfmt_str + from google.cloud.spanner_dbapi.parser import values + + cases = [ + ("VALUES ()", "", values([a_args([])])), + ("VALUES", "", values([])), + ("VALUES(%s)", "", values([a_args([pyfmt_str])])), + (" VALUES (%s) ", "", values([a_args([pyfmt_str])])), + ("VALUES(%s, %s)", "", values([a_args([pyfmt_str, pyfmt_str])])), + ( + "VALUES(%s, %s, LOWER(%s, %s))", + "", + values( + [ + a_args( + [ + pyfmt_str, + pyfmt_str, + func("LOWER", a_args([pyfmt_str, pyfmt_str])), + ] + ) + ] + ), + ), + ( + "VALUES (UPPER(%s)), (%s)", + "", + values( + [ + a_args([func("UPPER", a_args([pyfmt_str]))]), + a_args([pyfmt_str]), + ] + ), + ), + ] + + for text, want_unconsumed, want_parsed in cases: + with self.subTest(text=text): + got_unconsumed, got_parsed = expect(text, VALUES) + self.assertEqual(got_parsed, want_parsed) + self.assertEqual(got_unconsumed, want_unconsumed) + + def test_expect_values_fail(self): + from google.cloud.spanner_dbapi.exceptions import ProgrammingError + from google.cloud.spanner_dbapi.parser import VALUES + from google.cloud.spanner_dbapi.parser import expect + + cases = [ + ("", "VALUES: `` does not start with VALUES"), + ( + "VALUES(%s, %s, (%s, %s))", + "FUNC: `\\(%s, %s\\)\\)` does not begin with `a-zA-z` nor a `_`", + ), + ("VALUES(%s),,", "ARGS: supposed to begin with `\\(` in `,`"), + ] + + for text, wantException in cases: + with self.subTest(text=text): + self.assertRaisesRegex( + ProgrammingError, + wantException, + lambda: expect(text, VALUES), + ) + + def test_as_values(self): + from google.cloud.spanner_dbapi.parser import as_values + + values = (1, 2) + with mock.patch( + "google.cloud.spanner_dbapi.parser.parse_values", + return_value=values, + ): + self.assertEqual(as_values(None), values[1]) diff --git a/tests/unit/spanner_dbapi/test_types.py b/tests/unit/spanner_dbapi/test_types.py new file mode 100644 index 0000000000..4246a43e45 --- /dev/null +++ b/tests/unit/spanner_dbapi/test_types.py @@ -0,0 +1,63 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +import unittest + +from time import timezone + + +class TestTypes(unittest.TestCase): + + TICKS = 1572822862.9782631 + timezone # Sun 03 Nov 2019 23:14:22 UTC + + def test__date_from_ticks(self): + import datetime + + from google.cloud.spanner_dbapi import types + + actual = types._date_from_ticks(self.TICKS) + expected = datetime.date(2019, 11, 3) + + self.assertEqual(actual, expected) + + def test__time_from_ticks(self): + import datetime + + from google.cloud.spanner_dbapi import types + + actual = types._time_from_ticks(self.TICKS) + expected = datetime.time(23, 14, 22) + + self.assertEqual(actual, expected) + + def test__timestamp_from_ticks(self): + import datetime + + from google.cloud.spanner_dbapi import types + + actual = types._timestamp_from_ticks(self.TICKS) + expected = datetime.datetime(2019, 11, 3, 23, 14, 22) + + self.assertEqual(actual, expected) + + def test_type_equal(self): + from google.cloud.spanner_dbapi import types + + self.assertEqual(types.BINARY, "TYPE_CODE_UNSPECIFIED") + self.assertEqual(types.BINARY, "BYTES") + self.assertEqual(types.BINARY, "ARRAY") + self.assertEqual(types.BINARY, "STRUCT") + self.assertNotEqual(types.BINARY, "STRING") + + self.assertEqual(types.NUMBER, "BOOL") + self.assertEqual(types.NUMBER, "INT64") + self.assertEqual(types.NUMBER, "FLOAT64") + self.assertEqual(types.NUMBER, "NUMERIC") + self.assertNotEqual(types.NUMBER, "STRING") + + self.assertEqual(types.DATETIME, "TIMESTAMP") + self.assertEqual(types.DATETIME, "DATE") + self.assertNotEqual(types.DATETIME, "STRING") diff --git a/tests/unit/spanner_dbapi/test_utils.py b/tests/unit/spanner_dbapi/test_utils.py new file mode 100644 index 0000000000..90e1b7cf04 --- /dev/null +++ b/tests/unit/spanner_dbapi/test_utils.py @@ -0,0 +1,72 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +import unittest + + +class TestUtils(unittest.TestCase): + def test_PeekIterator(self): + from google.cloud.spanner_dbapi.utils import PeekIterator + + cases = [ + ("list", [1, 2, 3, 4, 6, 7], [1, 2, 3, 4, 6, 7]), + ("iter_from_list", iter([1, 2, 3, 4, 6, 7]), [1, 2, 3, 4, 6, 7]), + ("tuple", ("a", 12, 0xFF), ["a", 12, 0xFF]), + ("iter_from_tuple", iter(("a", 12, 0xFF)), ["a", 12, 0xFF]), + ("no_args", (), []), + ] + + for name, data_in, expected in cases: + with self.subTest(name=name): + pitr = PeekIterator(data_in) + actual = list(pitr) + self.assertEqual(actual, expected) + + def test_peekIterator_list_rows_converted_to_tuples(self): + from google.cloud.spanner_dbapi.utils import PeekIterator + + # Cloud Spanner returns results in lists e.g. [result]. + # PeekIterator is used by BaseCursor in its fetch* methods. + # This test ensures that anything passed into PeekIterator + # will be returned as a tuple. + pit = PeekIterator([["a"], ["b"], ["c"], ["d"], ["e"]]) + got = list(pit) + want = [("a",), ("b",), ("c",), ("d",), ("e",)] + self.assertEqual( + got, want, "Rows of type list must be returned as tuples" + ) + + seventeen = PeekIterator([[17]]) + self.assertEqual(list(seventeen), [(17,)]) + + pit = PeekIterator([["%", "%d"]]) + self.assertEqual(next(pit), ("%", "%d")) + + pit = PeekIterator([("Clark", "Kent")]) + self.assertEqual(next(pit), ("Clark", "Kent")) + + def test_peekIterator_nonlist_rows_unconverted(self): + from google.cloud.spanner_dbapi.utils import PeekIterator + + pi = PeekIterator(["a", "b", "c", "d", "e"]) + got = list(pi) + want = ["a", "b", "c", "d", "e"] + self.assertEqual(got, want, "Values should be returned unchanged") + + def test_backtick_unicode(self): + from google.cloud.spanner_dbapi.utils import backtick_unicode + + cases = [ + ("SELECT (1) as foo WHERE 1=1", "SELECT (1) as foo WHERE 1=1"), + ("SELECT (1) as föö", "SELECT (1) as `föö`"), + ("SELECT (1) as `föö`", "SELECT (1) as `föö`"), + ("SELECT (1) as `föö` `umläut", "SELECT (1) as `föö` `umläut"), + ("SELECT (1) as `föö", "SELECT (1) as `föö"), + ] + for sql, want in cases: + with self.subTest(sql=sql): + got = backtick_unicode(sql) + self.assertEqual(got, want) From 33b3fef7b46caa3d03770f0ed3c9cd57c61e84a0 Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Fri, 30 Oct 2020 19:49:18 -0400 Subject: [PATCH 02/12] chore: imports in test files rearranged --- tests/unit/spanner_dbapi/test__helpers.py | 3 +-- tests/unit/spanner_dbapi/test_connection.py | 3 +-- tests/unit/spanner_dbapi/test_cursor.py | 3 +-- tests/unit/spanner_dbapi/test_parser.py | 3 +-- 4 files changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/unit/spanner_dbapi/test__helpers.py b/tests/unit/spanner_dbapi/test__helpers.py index e5316d254e..bdeb86a73c 100644 --- a/tests/unit/spanner_dbapi/test__helpers.py +++ b/tests/unit/spanner_dbapi/test__helpers.py @@ -6,10 +6,9 @@ """Cloud Spanner DB-API Connection class unit tests.""" +import mock import unittest -from unittest import mock - class TestHelpers(unittest.TestCase): def test__execute_insert_heterogenous(self): diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index d545472c57..4a0d1a615f 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -6,11 +6,10 @@ """Cloud Spanner DB-API Connection class unit tests.""" +import mock import unittest import warnings -from unittest import mock - def _make_credentials(): from google.auth import credentials diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 09288df94e..d9b39ecb31 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -6,10 +6,9 @@ """Cursor() class unit tests.""" +import mock import unittest -from unittest import mock - class TestCursor(unittest.TestCase): diff --git a/tests/unit/spanner_dbapi/test_parser.py b/tests/unit/spanner_dbapi/test_parser.py index d5baf9d824..8b74e89d88 100644 --- a/tests/unit/spanner_dbapi/test_parser.py +++ b/tests/unit/spanner_dbapi/test_parser.py @@ -4,10 +4,9 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd +import mock import unittest -from unittest import mock - class TestParser(unittest.TestCase): def test_func(self): From 04ad15bb0a1783415df52dd64e5608feb7a93e5c Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Fri, 30 Oct 2020 21:12:35 -0400 Subject: [PATCH 03/12] chore: added coding directive --- tests/unit/spanner_dbapi/test_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unit/spanner_dbapi/test_utils.py b/tests/unit/spanner_dbapi/test_utils.py index 90e1b7cf04..86670b1995 100644 --- a/tests/unit/spanner_dbapi/test_utils.py +++ b/tests/unit/spanner_dbapi/test_utils.py @@ -4,6 +4,8 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd +# coding=utf-8 + import unittest From 38cfd7382bdee7c59f849e16a1b75214af8e1690 Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Fri, 30 Oct 2020 22:36:58 -0400 Subject: [PATCH 04/12] chore: --- tests/unit/spanner_dbapi/test_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/unit/spanner_dbapi/test_utils.py b/tests/unit/spanner_dbapi/test_utils.py index 86670b1995..c82eb60410 100644 --- a/tests/unit/spanner_dbapi/test_utils.py +++ b/tests/unit/spanner_dbapi/test_utils.py @@ -1,11 +1,10 @@ +# coding=utf-8 # Copyright 2020 Google LLC # # Use of this source code is governed by a BSD-style # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -# coding=utf-8 - import unittest From bbc378abaf09269665c264b165739c07eb880757 Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Sat, 31 Oct 2020 00:39:35 -0400 Subject: [PATCH 05/12] chore: encoding directive --- google/cloud/spanner_dbapi/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/google/cloud/spanner_dbapi/utils.py b/google/cloud/spanner_dbapi/utils.py index f4769e80a4..97a33fc0cc 100644 --- a/google/cloud/spanner_dbapi/utils.py +++ b/google/cloud/spanner_dbapi/utils.py @@ -1,3 +1,4 @@ +# coding=utf-8 # Copyright 2020 Google LLC # # Use of this source code is governed by a BSD-style From ecf2e49a9e4fa2a982f80472400acdadf8919488 Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Sat, 31 Oct 2020 13:13:29 -0400 Subject: [PATCH 06/12] chore: skipping Python 2 incompatible tests --- tests/unit/spanner_dbapi/test_connection.py | 2 ++ tests/unit/spanner_dbapi/test_cursor.py | 5 +++++ tests/unit/spanner_dbapi/test_parse_utils.py | 12 ++++++++++++ tests/unit/spanner_dbapi/test_parser.py | 11 +++++++++++ tests/unit/spanner_dbapi/test_utils.py | 7 +++++++ 5 files changed, 37 insertions(+) diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 4a0d1a615f..7080bd4f10 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -7,6 +7,7 @@ """Cloud Spanner DB-API Connection class unit tests.""" import mock +import sys import unittest import warnings @@ -42,6 +43,7 @@ def _make_connection(self): database = instance.database(self.DATABASE) return Connection(instance, database) + @unittest.skipIf(sys.version_info[0] < 3, 'Python 2 patching is outdated') def test_property_autocommit_setter(self): from google.cloud.spanner_dbapi import Connection diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index d9b39ecb31..74d935ebd0 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -7,6 +7,7 @@ """Cursor() class unit tests.""" import mock +import sys import unittest @@ -284,6 +285,10 @@ def test_executemany(self): (mock.call(operation, (1,)), mock.call(operation, (2,))) ) + @unittest.skipIf( + sys.version_info[0] < 3, + 'Python 2 has an outdated iterator definition' + ) def test_fetchone(self): connection = self._make_connection(self.INSTANCE, mock.MagicMock()) cursor = self._make_one(connection) diff --git a/tests/unit/spanner_dbapi/test_parse_utils.py b/tests/unit/spanner_dbapi/test_parse_utils.py index 1bd38c85eb..5e3b5a8e8f 100644 --- a/tests/unit/spanner_dbapi/test_parse_utils.py +++ b/tests/unit/spanner_dbapi/test_parse_utils.py @@ -4,12 +4,17 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd +import sys import unittest from google.cloud.spanner_v1 import param_types class TestParseUtils(unittest.TestCase): + + skip_condition = sys.version_info[0] < 3 + skip_message = 'Subtests are not supported in Python 2' + def test_classify_stmt(self): from google.cloud.spanner_dbapi.parse_utils import STMT_DDL from google.cloud.spanner_dbapi.parse_utils import STMT_INSERT @@ -46,6 +51,7 @@ def test_classify_stmt(self): for query, want_class in cases: self.assertEqual(classify_stmt(query), want_class) + @unittest.skipIf(skip_condition, skip_message) def test_parse_insert(self): from google.cloud.spanner_dbapi.parse_utils import parse_insert from google.cloud.spanner_dbapi.exceptions import ProgrammingError @@ -176,6 +182,7 @@ def test_parse_insert(self): got, want, "Mismatch with parse_insert of `%s`" % sql ) + @unittest.skipIf(skip_condition, skip_message) def test_parse_insert_invalid(self): from google.cloud.spanner_dbapi import exceptions from google.cloud.spanner_dbapi.parse_utils import parse_insert @@ -206,6 +213,7 @@ def test_parse_insert_invalid(self): lambda: parse_insert(sql, params), ) + @unittest.skipIf(skip_condition, skip_message) def test_rows_for_insert_or_update(self): from google.cloud.spanner_dbapi.parse_utils import ( rows_for_insert_or_update, @@ -264,6 +272,7 @@ def test_rows_for_insert_or_update(self): got = rows_for_insert_or_update(columns, params, pyformat_args) self.assertEqual(got, want) + @unittest.skipIf(skip_condition, skip_message) def test_sql_pyformat_args_to_spanner(self): import decimal @@ -338,6 +347,7 @@ def test_sql_pyformat_args_to_spanner(self): got_named_args, want_named_args, "Named args do not match" ) + @unittest.skipIf(skip_condition, skip_message) def test_sql_pyformat_args_to_spanner_invalid(self): from google.cloud.spanner_dbapi import exceptions from google.cloud.spanner_dbapi.parse_utils import ( @@ -406,6 +416,7 @@ def test_get_param_types_none(self): self.assertEqual(get_param_types(None), None) + @unittest.skipIf(skip_condition, skip_message) def test_ensure_where_clause(self): from google.cloud.spanner_dbapi.parse_utils import ensure_where_clause @@ -438,6 +449,7 @@ def test_ensure_where_clause(self): got = ensure_where_clause(sql) self.assertEqual(got, want) + @unittest.skipIf(skip_condition, skip_message) def test_escape_name(self): from google.cloud.spanner_dbapi.parse_utils import escape_name diff --git a/tests/unit/spanner_dbapi/test_parser.py b/tests/unit/spanner_dbapi/test_parser.py index 8b74e89d88..bdf4140d30 100644 --- a/tests/unit/spanner_dbapi/test_parser.py +++ b/tests/unit/spanner_dbapi/test_parser.py @@ -5,10 +5,16 @@ # https://developers.google.com/open-source/licenses/bsd import mock +import sys import unittest class TestParser(unittest.TestCase): + + skip_condition = sys.version_info[0] < 3 + skip_message = 'Subtests are not supported in Python 2' + + @unittest.skipIf(skip_condition, skip_message) def test_func(self): from google.cloud.spanner_dbapi.parser import FUNC from google.cloud.spanner_dbapi.parser import a_args @@ -55,6 +61,7 @@ def test_func(self): self.assertEqual(got_parsed, want_parsed) self.assertEqual(got_unconsumed, want_unconsumed) + @unittest.skipIf(skip_condition, skip_message) def test_func_fail(self): from google.cloud.spanner_dbapi.exceptions import ProgrammingError from google.cloud.spanner_dbapi.parser import FUNC @@ -92,6 +99,7 @@ def test_func_eq(self): func2.args = func1.args self.assertTrue(func1 == func2) + @unittest.skipIf(skip_condition, skip_message) def test_a_args(self): from google.cloud.spanner_dbapi.parser import ARGS from google.cloud.spanner_dbapi.parser import a_args @@ -123,6 +131,7 @@ def test_a_args(self): self.assertEqual(got_parsed, want_parsed) self.assertEqual(got_unconsumed, want_unconsumed) + @unittest.skipIf(skip_condition, skip_message) def test_a_args_fail(self): from google.cloud.spanner_dbapi.exceptions import ProgrammingError from google.cloud.spanner_dbapi.parser import ARGS @@ -207,6 +216,7 @@ def test_expect(self): with self.assertRaises(exceptions.ProgrammingError): expect(word="", token="ABC") + @unittest.skipIf(skip_condition, skip_message) def test_expect_values(self): from google.cloud.spanner_dbapi.parser import VALUES from google.cloud.spanner_dbapi.parser import a_args @@ -254,6 +264,7 @@ def test_expect_values(self): self.assertEqual(got_parsed, want_parsed) self.assertEqual(got_unconsumed, want_unconsumed) + @unittest.skipIf(skip_condition, skip_message) def test_expect_values_fail(self): from google.cloud.spanner_dbapi.exceptions import ProgrammingError from google.cloud.spanner_dbapi.parser import VALUES diff --git a/tests/unit/spanner_dbapi/test_utils.py b/tests/unit/spanner_dbapi/test_utils.py index c82eb60410..e3d0b4613c 100644 --- a/tests/unit/spanner_dbapi/test_utils.py +++ b/tests/unit/spanner_dbapi/test_utils.py @@ -5,10 +5,16 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd +import sys import unittest class TestUtils(unittest.TestCase): + + skip_condition = sys.version_info[0] < 3 + skip_message = 'Subtests are not supported in Python 2' + + @unittest.skipIf(skip_condition, skip_message) def test_PeekIterator(self): from google.cloud.spanner_dbapi.utils import PeekIterator @@ -57,6 +63,7 @@ def test_peekIterator_nonlist_rows_unconverted(self): want = ["a", "b", "c", "d", "e"] self.assertEqual(got, want, "Values should be returned unchanged") + @unittest.skipIf(skip_condition, skip_message) def test_backtick_unicode(self): from google.cloud.spanner_dbapi.utils import backtick_unicode From f69dc3bb6207a9600dff5fa515fe4e4119e0743f Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Sat, 31 Oct 2020 14:57:02 -0400 Subject: [PATCH 07/12] chore: skipping Python 2 incompatible tests --- tests/unit/spanner_dbapi/test_parser.py | 3 ++- tests/unit/spanner_dbapi/test_utils.py | 8 ++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/unit/spanner_dbapi/test_parser.py b/tests/unit/spanner_dbapi/test_parser.py index bdf4140d30..5df0f2b4fb 100644 --- a/tests/unit/spanner_dbapi/test_parser.py +++ b/tests/unit/spanner_dbapi/test_parser.py @@ -190,7 +190,8 @@ def test_values(self): from google.cloud.spanner_dbapi.parser import values a_obj = a_args([a_args([terminal(10 ** i)]) for i in range(10)]) - self.assertEqual(str(values(a_obj)), "VALUES%s" % str(a_obj)) + # self.assertEqual(str(values(a_obj)), "VALUES%s" % str(a_obj)) + self.assertEqual(str(values(a_obj)), "VALUES{}".format(a_obj)) def test_expect(self): from google.cloud.spanner_dbapi.parser import ARGS diff --git a/tests/unit/spanner_dbapi/test_utils.py b/tests/unit/spanner_dbapi/test_utils.py index e3d0b4613c..1547ce0cb3 100644 --- a/tests/unit/spanner_dbapi/test_utils.py +++ b/tests/unit/spanner_dbapi/test_utils.py @@ -32,6 +32,10 @@ def test_PeekIterator(self): actual = list(pitr) self.assertEqual(actual, expected) + @unittest.skipIf( + skip_condition, + 'Python 2 has an outdated iterator definition' + ) def test_peekIterator_list_rows_converted_to_tuples(self): from google.cloud.spanner_dbapi.utils import PeekIterator @@ -55,6 +59,10 @@ def test_peekIterator_list_rows_converted_to_tuples(self): pit = PeekIterator([("Clark", "Kent")]) self.assertEqual(next(pit), ("Clark", "Kent")) + @unittest.skipIf( + skip_condition, + 'Python 2 has an outdated iterator definition' + ) def test_peekIterator_nonlist_rows_unconverted(self): from google.cloud.spanner_dbapi.utils import PeekIterator From 403c0971d1b504abacee4d9de90c85d91ee131fe Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Sat, 31 Oct 2020 16:17:07 -0400 Subject: [PATCH 08/12] chore: skipping Python 2 incompatible tests --- tests/unit/spanner_dbapi/test_parser.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/unit/spanner_dbapi/test_parser.py b/tests/unit/spanner_dbapi/test_parser.py index 5df0f2b4fb..ae1d3e0da7 100644 --- a/tests/unit/spanner_dbapi/test_parser.py +++ b/tests/unit/spanner_dbapi/test_parser.py @@ -184,14 +184,21 @@ def test_a_args__is_equal_length(self): a_obj = a_args([]) self.assertTrue(a_obj._is_equal_length()) + @unittest.skipIf( + skip_condition, + 'Python 2 has an outdated iterator definition' + ) + @unittest.skipIf( + skip_condition, + 'Python 2 does not support 0-argument super() calls' + ) def test_values(self): from google.cloud.spanner_dbapi.parser import a_args from google.cloud.spanner_dbapi.parser import terminal from google.cloud.spanner_dbapi.parser import values a_obj = a_args([a_args([terminal(10 ** i)]) for i in range(10)]) - # self.assertEqual(str(values(a_obj)), "VALUES%s" % str(a_obj)) - self.assertEqual(str(values(a_obj)), "VALUES{}".format(a_obj)) + self.assertEqual(str(values(a_obj)), "VALUES%s" % str(a_obj)) def test_expect(self): from google.cloud.spanner_dbapi.parser import ARGS From 541ce567da5734aaac72f33cbf16326eef796f73 Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Sat, 31 Oct 2020 17:25:46 -0400 Subject: [PATCH 09/12] chore: skipping Python 2 incompatible tests --- tests/unit/spanner_dbapi/test_parse_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/spanner_dbapi/test_parse_utils.py b/tests/unit/spanner_dbapi/test_parse_utils.py index 5e3b5a8e8f..6e87b3d3e0 100644 --- a/tests/unit/spanner_dbapi/test_parse_utils.py +++ b/tests/unit/spanner_dbapi/test_parse_utils.py @@ -378,6 +378,7 @@ def test_cast_for_spanner(self): self.assertEqual(cast_for_spanner(5), 5) self.assertEqual(cast_for_spanner("string"), "string") + @unittest.skipIf(skip_condition, skip_message) def test_get_param_types(self): import datetime From 054e8414987a54267314efd35d26bb2dc359b56c Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Sat, 31 Oct 2020 18:35:55 -0400 Subject: [PATCH 10/12] chore: lint format --- google/cloud/spanner_dbapi/_helpers.py | 12 +--- google/cloud/spanner_dbapi/connection.py | 11 +--- google/cloud/spanner_dbapi/cursor.py | 8 +-- google/cloud/spanner_dbapi/parse_utils.py | 8 +-- google/cloud/spanner_dbapi/parser.py | 16 ++--- tests/unit/spanner_dbapi/test__helpers.py | 28 ++------- tests/unit/spanner_dbapi/test_connection.py | 49 +++++---------- tests/unit/spanner_dbapi/test_cursor.py | 41 ++++-------- tests/unit/spanner_dbapi/test_parse_utils.py | 65 +++++--------------- tests/unit/spanner_dbapi/test_parser.py | 33 +++------- tests/unit/spanner_dbapi/test_utils.py | 16 ++--- 11 files changed, 72 insertions(+), 215 deletions(-) diff --git a/google/cloud/spanner_dbapi/_helpers.py b/google/cloud/spanner_dbapi/_helpers.py index f581fdebbd..e8b981c4d0 100644 --- a/google/cloud/spanner_dbapi/_helpers.py +++ b/google/cloud/spanner_dbapi/_helpers.py @@ -51,9 +51,7 @@ def _execute_insert_heterogenous(transaction, sql_params_list): for sql, params in sql_params_list: sql, params = sql_pyformat_args_to_spanner(sql, params) param_types = get_param_types(params) - res = transaction.execute_sql( - sql, params=params, param_types=param_types - ) + res = transaction.execute_sql(sql, params=params, param_types=param_types) # TODO: File a bug with Cloud Spanner and the Python client maintainers # about a lost commit when res isn't read from. _ = list(res) @@ -86,9 +84,7 @@ def handle_insert(connection, sql, params): if parts.get("homogenous"): # The common case of multiple values being passed in # non-complex pyformat args and need to be uploaded in one RPC. - return connection.database.run_in_transaction( - _execute_insert_homogenous, parts - ) + return connection.database.run_in_transaction(_execute_insert_homogenous, parts) else: # All the other cases that are esoteric and need # transaction.execute_sql @@ -148,9 +144,7 @@ def __str__(self): "internal_size=%d" % self.internal_size if self.internal_size else None, - "precision='%s'" % self.precision - if self.precision - else None, + "precision='%s'" % self.precision if self.precision else None, "scale='%s'" % self.scale if self.scale else None, "null_ok='%s'" % self.null_ok if self.null_ok else None, ], diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index b572c8573b..45b69bc067 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -199,12 +199,7 @@ def __exit__(self, etype, value, traceback): def connect( - instance_id, - database_id, - project=None, - credentials=None, - pool=None, - user_agent=None, + instance_id, database_id, project=None, credentials=None, pool=None, user_agent=None ): """Creates a connection to a Google Cloud Spanner database. @@ -243,11 +238,11 @@ def connect( """ client_info = ClientInfo( - user_agent=user_agent or DEFAULT_USER_AGENT, python_version=PY_VERSION, + user_agent=user_agent or DEFAULT_USER_AGENT, python_version=PY_VERSION ) client = spanner.Client( - project=project, credentials=credentials, client_info=client_info, + project=project, credentials=credentials, client_info=client_info ) instance = client.instance(instance_id) diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 6997752a42..96433c1d0c 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -160,9 +160,7 @@ def execute(self, sql, args=None): if not self.connection.autocommit: transaction = self.connection.transaction_checkout() - sql, params = parse_utils.sql_pyformat_args_to_spanner( - sql, args - ) + sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, args) self._result_set = transaction.execute_sql( sql, params, param_types=get_param_types(params) @@ -309,9 +307,7 @@ def run_sql_in_snapshot(self, sql, params=None, param_types=None): self.connection.run_prior_DDL_statements() with self.connection.database.snapshot() as snapshot: - res = snapshot.execute_sql( - sql, params=params, param_types=param_types - ) + res = snapshot.execute_sql(sql, params=params, param_types=param_types) return list(res) def get_table_column_schema(self, table_name): diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py index 084eea315e..8201f9a19c 100644 --- a/google/cloud/spanner_dbapi/parse_utils.py +++ b/google/cloud/spanner_dbapi/parse_utils.py @@ -271,8 +271,7 @@ def parse_insert(insert_sql, params): for item in after_values_sql: if item.count("%s") > 0: raise ProgrammingError( - 'no params yet there are %d "%%s" tokens' - % item.count("%s") + 'no params yet there are %d "%%s" tokens' % item.count("%s") ) insert_sql = sanitize_literals_for_upload(insert_sql) @@ -518,10 +517,7 @@ def ensure_where_clause(sql): Cloud Spanner requires a WHERE clause on UPDATE and DELETE statements. Add a dummy WHERE clause if necessary. """ - if any( - isinstance(token, sqlparse.sql.Where) - for token in sqlparse.parse(sql)[0] - ): + if any(isinstance(token, sqlparse.sql.Where) for token in sqlparse.parse(sql)[0]): return sql return sql + " WHERE 1=1" diff --git a/google/cloud/spanner_dbapi/parser.py b/google/cloud/spanner_dbapi/parser.py index 2fc0156b57..074d733c72 100644 --- a/google/cloud/spanner_dbapi/parser.py +++ b/google/cloud/spanner_dbapi/parser.py @@ -78,9 +78,7 @@ def __repr__(self): return self.__str__() def has_expr(self): - return any( - [token for token in self.argv if not isinstance(token, terminal)] - ) + return any([token for token in self.argv if not isinstance(token, terminal)]) def __len__(self): return len(self.argv) @@ -150,9 +148,7 @@ def expect(word, token): word = word.strip() if token == VALUES: if not word.startswith("VALUES"): - raise ProgrammingError( - "VALUES: `%s` does not start with VALUES" % word - ) + raise ProgrammingError("VALUES: `%s` does not start with VALUES" % word) word = word[len("VALUES") :].lstrip() all_args = [] @@ -197,9 +193,7 @@ def expect(word, token): # (FUNC, %s...) # (%s, %s...) if not (word and word.startswith("(")): - raise ProgrammingError( - "ARGS: supposed to begin with `(` in `%s`" % word - ) + raise ProgrammingError("ARGS: supposed to begin with `(` in `%s`" % word) word = word[1:] @@ -223,9 +217,7 @@ def expect(word, token): word = word[1:] if not (word and word.startswith(")")): - raise ProgrammingError( - "ARGS: supposed to end with `)` in `%s`" % word - ) + raise ProgrammingError("ARGS: supposed to end with `)` in `%s`" % word) word = word[1:] return word, a_args(terms) diff --git a/tests/unit/spanner_dbapi/test__helpers.py b/tests/unit/spanner_dbapi/test__helpers.py index bdeb86a73c..c52c617543 100644 --- a/tests/unit/spanner_dbapi/test__helpers.py +++ b/tests/unit/spanner_dbapi/test__helpers.py @@ -21,8 +21,7 @@ def test__execute_insert_heterogenous(self): return_value=params, ) as mock_pyformat: with mock.patch( - "google.cloud.spanner_dbapi._helpers.get_param_types", - return_value=None, + "google.cloud.spanner_dbapi._helpers.get_param_types", return_value=None ) as mock_param_types: transaction = mock.MagicMock() transaction.execute_sql = mock_execute = mock.MagicMock() @@ -30,9 +29,7 @@ def test__execute_insert_heterogenous(self): mock_pyformat.assert_called_once_with(params[0], params[1]) mock_param_types.assert_called_once_with(None) - mock_execute.assert_called_once_with( - sql, params=None, param_types=None - ) + mock_execute.assert_called_once_with(sql, params=None, param_types=None) def test__execute_insert_homogenous(self): from google.cloud.spanner_dbapi import _helpers @@ -53,8 +50,7 @@ def test_handle_insert(self): sql = "sql" parts = mock.MagicMock() with mock.patch( - "google.cloud.spanner_dbapi._helpers.parse_insert", - return_value=parts, + "google.cloud.spanner_dbapi._helpers.parse_insert", return_value=parts ): parts.get = mock.MagicMock(return_value=True) mock_run_in.return_value = 0 @@ -80,13 +76,7 @@ def test_ctor(self): null_ok = False cols = ColumnInfo( - name, - type_code, - display_size, - internal_size, - precision, - scale, - null_ok, + name, type_code, display_size, internal_size, precision, scale, null_ok ) self.assertEqual(cols.name, name) @@ -98,15 +88,7 @@ def test_ctor(self): self.assertEqual(cols.null_ok, null_ok) self.assertEqual( cols.fields, - ( - name, - type_code, - display_size, - internal_size, - precision, - scale, - null_ok, - ), + (name, type_code, display_size, internal_size, precision, scale, null_ok), ) def test___get_item__(self): diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 7080bd4f10..bd9dd80c8c 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -43,7 +43,7 @@ def _make_connection(self): database = instance.database(self.DATABASE) return Connection(instance, database) - @unittest.skipIf(sys.version_info[0] < 3, 'Python 2 patching is outdated') + @unittest.skipIf(sys.version_info[0] < 3, "Python 2 patching is outdated") def test_property_autocommit_setter(self): from google.cloud.spanner_dbapi import Connection @@ -80,13 +80,9 @@ def test_property_instance(self): def test__session_checkout(self): from google.cloud.spanner_dbapi import Connection - with mock.patch( - "google.cloud.spanner_v1.database.Database", - ) as mock_database: + with mock.patch("google.cloud.spanner_v1.database.Database") as mock_database: mock_database._pool = mock.MagicMock() - mock_database._pool.get = mock.MagicMock( - return_value="db_session_pool" - ) + mock_database._pool.get = mock.MagicMock(return_value="db_session_pool") connection = Connection(self.INSTANCE, mock_database) connection._session_checkout() @@ -100,9 +96,7 @@ def test__session_checkout(self): def test__release_session(self): from google.cloud.spanner_dbapi import Connection - with mock.patch( - "google.cloud.spanner_v1.database.Database", - ) as mock_database: + with mock.patch("google.cloud.spanner_v1.database.Database") as mock_database: mock_database._pool = mock.MagicMock() mock_database._pool.put = mock.MagicMock() connection = Connection(self.INSTANCE, mock_database) @@ -116,9 +110,7 @@ def test_transaction_checkout(self): from google.cloud.spanner_dbapi import Connection connection = Connection(self.INSTANCE, self.DATABASE) - connection._session_checkout = mock_checkout = mock.MagicMock( - autospec=True - ) + connection._session_checkout = mock_checkout = mock.MagicMock(autospec=True) connection.transaction_checkout() mock_checkout.assert_called_once_with() @@ -133,12 +125,10 @@ def test_close(self): from google.cloud.spanner_dbapi import connect, InterfaceError with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True ): with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", - return_value=True, + "google.cloud.spanner_v1.database.Database.exists", return_value=True ): connection = connect("test-instance", "test-database") @@ -158,9 +148,7 @@ def test_close(self): @mock.patch.object(warnings, "warn") def test_commit(self, mock_warn): from google.cloud.spanner_dbapi import Connection - from google.cloud.spanner_dbapi.connection import ( - AUTOCOMMIT_MODE_WARNING, - ) + from google.cloud.spanner_dbapi.connection import AUTOCOMMIT_MODE_WARNING connection = Connection(self.INSTANCE, self.DATABASE) @@ -189,9 +177,7 @@ def test_commit(self, mock_warn): @mock.patch.object(warnings, "warn") def test_rollback(self, mock_warn): from google.cloud.spanner_dbapi import Connection - from google.cloud.spanner_dbapi.connection import ( - AUTOCOMMIT_MODE_WARNING, - ) + from google.cloud.spanner_dbapi.connection import AUTOCOMMIT_MODE_WARNING connection = Connection(self.INSTANCE, self.DATABASE) @@ -221,7 +207,7 @@ def test_run_prior_DDL_statements(self): from google.cloud.spanner_dbapi import Connection, InterfaceError with mock.patch( - "google.cloud.spanner_v1.database.Database", autospec=True, + "google.cloud.spanner_v1.database.Database", autospec=True ) as mock_database: connection = Connection(self.INSTANCE, mock_database) @@ -269,8 +255,7 @@ def test_connect_instance_not_found(self): from google.cloud.spanner_dbapi import connect with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=False, + "google.cloud.spanner_v1.instance.Instance.exists", return_value=False ): with self.assertRaises(ValueError): connect("test-instance", "test-database") @@ -279,12 +264,10 @@ def test_connect_database_not_found(self): from google.cloud.spanner_dbapi import connect with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", - return_value=False, + "google.cloud.spanner_v1.database.Database.exists", return_value=False ): with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True ): with self.assertRaises(ValueError): connect("test-instance", "test-database") @@ -294,8 +277,7 @@ def test_default_sessions_pool(self): with mock.patch("google.cloud.spanner_v1.instance.Instance.database"): with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True ): connection = connect("test-instance", "test-database") @@ -312,8 +294,7 @@ def test_sessions_pool(self): "google.cloud.spanner_v1.instance.Instance.database" ) as database_mock: with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True ): connect("test-instance", database_id, pool=pool) database_mock.assert_called_once_with(database_id, pool=pool) diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 74d935ebd0..78ac98bc9f 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -66,12 +66,10 @@ def test_close(self): from google.cloud.spanner_dbapi import connect, InterfaceError with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True ): with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", - return_value=True, + "google.cloud.spanner_v1.database.Database.exists", return_value=True ): connection = connect(self.INSTANCE, self.DATABASE) @@ -94,7 +92,7 @@ def test_do_execute_update(self): def run_helper(ret_value): transaction.execute_update.return_value = ret_value res = cursor._do_execute_update( - transaction=transaction, sql="sql", params=None, + transaction=transaction, sql="sql", params=None ) return res @@ -172,9 +170,7 @@ def test_execute_statement(self): ) as mock_handle_insert: sql = "sql" cursor.execute(sql=sql) - mock_handle_insert.assert_called_once_with( - connection, sql, None - ) + mock_handle_insert.assert_called_once_with(connection, sql, None) with mock.patch( "google.cloud.spanner_dbapi.parse_utils.classify_stmt", @@ -184,9 +180,7 @@ def test_execute_statement(self): mock_db.run_in_transaction = mock_run_in = mock.MagicMock() sql = "sql" cursor.execute(sql=sql) - mock_run_in.assert_called_once_with( - cursor._do_execute_update, sql, None - ) + mock_run_in.assert_called_once_with(cursor._do_execute_update, sql, None) def test_execute_integrity_error(self): from google.api_core import exceptions @@ -242,12 +236,10 @@ def test_executemany_on_closed_cursor(self): from google.cloud.spanner_dbapi import connect with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True ): with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", - return_value=True, + "google.cloud.spanner_v1.database.Database.exists", return_value=True ): connection = connect("test-instance", "test-database") @@ -255,9 +247,7 @@ def test_executemany_on_closed_cursor(self): cursor.close() with self.assertRaises(InterfaceError): - cursor.executemany( - """SELECT * FROM table1 WHERE "col1" = @a1""", () - ) + cursor.executemany("""SELECT * FROM table1 WHERE "col1" = @a1""", ()) def test_executemany(self): from google.cloud.spanner_dbapi import connect @@ -266,12 +256,10 @@ def test_executemany(self): params_seq = ((1,), (2,)) with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True ): with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", - return_value=True, + "google.cloud.spanner_v1.database.Database.exists", return_value=True ): connection = connect("test-instance", "test-database") @@ -286,8 +274,7 @@ def test_executemany(self): ) @unittest.skipIf( - sys.version_info[0] < 3, - 'Python 2 has an outdated iterator definition' + sys.version_info[0] < 3, "Python 2 has an outdated iterator definition" ) def test_fetchone(self): connection = self._make_connection(self.INSTANCE, mock.MagicMock()) @@ -445,11 +432,7 @@ def test_get_table_column_schema(self): is_nullable = "YES" spanner_type = "spanner_type" rows = [(column_name, is_nullable, spanner_type)] - expected = { - column_name: ColumnDetails( - null_ok=True, spanner_type=spanner_type, - ) - } + expected = {column_name: ColumnDetails(null_ok=True, spanner_type=spanner_type)} with mock.patch( "google.cloud.spanner_dbapi.cursor.Cursor.run_sql_in_snapshot", return_value=rows, diff --git a/tests/unit/spanner_dbapi/test_parse_utils.py b/tests/unit/spanner_dbapi/test_parse_utils.py index 6e87b3d3e0..4417e7e0c0 100644 --- a/tests/unit/spanner_dbapi/test_parse_utils.py +++ b/tests/unit/spanner_dbapi/test_parse_utils.py @@ -13,7 +13,7 @@ class TestParseUtils(unittest.TestCase): skip_condition = sys.version_info[0] < 3 - skip_message = 'Subtests are not supported in Python 2' + skip_message = "Subtests are not supported in Python 2" def test_classify_stmt(self): from google.cloud.spanner_dbapi.parse_utils import STMT_DDL @@ -114,18 +114,9 @@ def test_parse_insert(self): (1, 2, 3, 4, 5, 6, 7, 8, 9), { "sql_params_list": [ - ( - "INSERT INTO ap (n, ct, cn) VALUES (%s, %s, %s)", - (1, 2, 3), - ), - ( - "INSERT INTO ap (n, ct, cn) VALUES (%s, %s, %s)", - (4, 5, 6), - ), - ( - "INSERT INTO ap (n, ct, cn) VALUES (%s, %s, %s)", - (7, 8, 9), - ), + ("INSERT INTO ap (n, ct, cn) VALUES (%s, %s, %s)", (1, 2, 3)), + ("INSERT INTO ap (n, ct, cn) VALUES (%s, %s, %s)", (4, 5, 6)), + ("INSERT INTO ap (n, ct, cn) VALUES (%s, %s, %s)", (7, 8, 9)), ] }, ), @@ -143,11 +134,7 @@ def test_parse_insert(self): ( "INSERT INTO T (f1, f2) VALUES (1, 2)", None, - { - "sql_params_list": [ - ("INSERT INTO T (f1, f2) VALUES (1, 2)", None) - ] - }, + {"sql_params_list": [("INSERT INTO T (f1, f2) VALUES (1, 2)", None)]}, ), ( "INSERT INTO `no` (`yes`, tiff) VALUES (%s, LOWER(%s)), (%s, %s), (%s, %s)", @@ -158,14 +145,8 @@ def test_parse_insert(self): "INSERT INTO `no` (`yes`, tiff) VALUES(%s, LOWER(%s))", (1, "FOO"), ), - ( - "INSERT INTO `no` (`yes`, tiff) VALUES(%s, %s)", - (5, 10), - ), - ( - "INSERT INTO `no` (`yes`, tiff) VALUES(%s, %s)", - (11, 29), - ), + ("INSERT INTO `no` (`yes`, tiff) VALUES(%s, %s)", (5, 10)), + ("INSERT INTO `no` (`yes`, tiff) VALUES(%s, %s)", (11, 29)), ] }, ), @@ -178,9 +159,7 @@ def test_parse_insert(self): for sql, params, want in cases: with self.subTest(sql=sql): got = parse_insert(sql, params) - self.assertEqual( - got, want, "Mismatch with parse_insert of `%s`" % sql - ) + self.assertEqual(got, want, "Mismatch with parse_insert of `%s`" % sql) @unittest.skipIf(skip_condition, skip_message) def test_parse_insert_invalid(self): @@ -215,9 +194,7 @@ def test_parse_insert_invalid(self): @unittest.skipIf(skip_condition, skip_message) def test_rows_for_insert_or_update(self): - from google.cloud.spanner_dbapi.parse_utils import ( - rows_for_insert_or_update, - ) + from google.cloud.spanner_dbapi.parse_utils import rows_for_insert_or_update from google.cloud.spanner_dbapi.exceptions import Error with self.assertRaises(Error): @@ -259,12 +236,7 @@ def test_rows_for_insert_or_update(self): ("fp", "cp", "o", "f3"), ], ), - ( - ["app", "name", "fn"], - ["ap", "n", "f1"], - None, - [("ap", "n", "f1")], - ), + (["app", "name", "fn"], ["ap", "n", "f1"], None, [("ap", "n", "f1")]), ] for i, (columns, params, pyformat_args, want) in enumerate(cases): @@ -276,9 +248,7 @@ def test_rows_for_insert_or_update(self): def test_sql_pyformat_args_to_spanner(self): import decimal - from google.cloud.spanner_dbapi.parse_utils import ( - sql_pyformat_args_to_spanner, - ) + from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner cases = [ ( @@ -313,10 +283,7 @@ def test_sql_pyformat_args_to_spanner(self): ), ( # Intentionally using a dict with more keys than will be resolved. - ( - "SELECT * from t WHERE f1=%(f1)s", - {"f1": "app", "f2": "name"}, - ), + ("SELECT * from t WHERE f1=%(f1)s", {"f1": "app", "f2": "name"}), ("SELECT * from t WHERE f1=@a0", {"a0": "app"}), ), ( @@ -338,9 +305,7 @@ def test_sql_pyformat_args_to_spanner(self): ] for ((sql_in, params), sql_want) in cases: with self.subTest(sql=sql_in): - got_sql, got_named_args = sql_pyformat_args_to_spanner( - sql_in, params - ) + got_sql, got_named_args = sql_pyformat_args_to_spanner(sql_in, params) want_sql, want_named_args = sql_want self.assertEqual(got_sql, want_sql, "SQL does not match") self.assertEqual( @@ -350,9 +315,7 @@ def test_sql_pyformat_args_to_spanner(self): @unittest.skipIf(skip_condition, skip_message) def test_sql_pyformat_args_to_spanner_invalid(self): from google.cloud.spanner_dbapi import exceptions - from google.cloud.spanner_dbapi.parse_utils import ( - sql_pyformat_args_to_spanner, - ) + from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner cases = [ ( diff --git a/tests/unit/spanner_dbapi/test_parser.py b/tests/unit/spanner_dbapi/test_parser.py index ae1d3e0da7..b203328024 100644 --- a/tests/unit/spanner_dbapi/test_parser.py +++ b/tests/unit/spanner_dbapi/test_parser.py @@ -12,7 +12,7 @@ class TestParser(unittest.TestCase): skip_condition = sys.version_info[0] < 3 - skip_message = 'Subtests are not supported in Python 2' + skip_message = "Subtests are not supported in Python 2" @unittest.skipIf(skip_condition, skip_message) def test_func(self): @@ -42,10 +42,7 @@ def test_func(self): [ pyfmt_str, pyfmt_str, - func( - "TAN", - a_args([pyfmt_str, pyfmt_str]), - ), + func("TAN", a_args([pyfmt_str, pyfmt_str])), ] ), ), @@ -116,11 +113,7 @@ def test_a_args(self): "(%s,%s, f1(%s, %s))", "", a_args( - [ - pyfmt_str, - pyfmt_str, - func("f1", a_args([pyfmt_str, pyfmt_str])), - ] + [pyfmt_str, pyfmt_str, func("f1", a_args([pyfmt_str, pyfmt_str]))] ), ), ] @@ -184,13 +177,9 @@ def test_a_args__is_equal_length(self): a_obj = a_args([]) self.assertTrue(a_obj._is_equal_length()) + @unittest.skipIf(skip_condition, "Python 2 has an outdated iterator definition") @unittest.skipIf( - skip_condition, - 'Python 2 has an outdated iterator definition' - ) - @unittest.skipIf( - skip_condition, - 'Python 2 does not support 0-argument super() calls' + skip_condition, "Python 2 does not support 0-argument super() calls" ) def test_values(self): from google.cloud.spanner_dbapi.parser import a_args @@ -258,10 +247,7 @@ def test_expect_values(self): "VALUES (UPPER(%s)), (%s)", "", values( - [ - a_args([func("UPPER", a_args([pyfmt_str]))]), - a_args([pyfmt_str]), - ] + [a_args([func("UPPER", a_args([pyfmt_str]))]), a_args([pyfmt_str])] ), ), ] @@ -290,9 +276,7 @@ def test_expect_values_fail(self): for text, wantException in cases: with self.subTest(text=text): self.assertRaisesRegex( - ProgrammingError, - wantException, - lambda: expect(text, VALUES), + ProgrammingError, wantException, lambda: expect(text, VALUES) ) def test_as_values(self): @@ -300,7 +284,6 @@ def test_as_values(self): values = (1, 2) with mock.patch( - "google.cloud.spanner_dbapi.parser.parse_values", - return_value=values, + "google.cloud.spanner_dbapi.parser.parse_values", return_value=values ): self.assertEqual(as_values(None), values[1]) diff --git a/tests/unit/spanner_dbapi/test_utils.py b/tests/unit/spanner_dbapi/test_utils.py index 1547ce0cb3..64a6130aa0 100644 --- a/tests/unit/spanner_dbapi/test_utils.py +++ b/tests/unit/spanner_dbapi/test_utils.py @@ -12,7 +12,7 @@ class TestUtils(unittest.TestCase): skip_condition = sys.version_info[0] < 3 - skip_message = 'Subtests are not supported in Python 2' + skip_message = "Subtests are not supported in Python 2" @unittest.skipIf(skip_condition, skip_message) def test_PeekIterator(self): @@ -32,10 +32,7 @@ def test_PeekIterator(self): actual = list(pitr) self.assertEqual(actual, expected) - @unittest.skipIf( - skip_condition, - 'Python 2 has an outdated iterator definition' - ) + @unittest.skipIf(skip_condition, "Python 2 has an outdated iterator definition") def test_peekIterator_list_rows_converted_to_tuples(self): from google.cloud.spanner_dbapi.utils import PeekIterator @@ -46,9 +43,7 @@ def test_peekIterator_list_rows_converted_to_tuples(self): pit = PeekIterator([["a"], ["b"], ["c"], ["d"], ["e"]]) got = list(pit) want = [("a",), ("b",), ("c",), ("d",), ("e",)] - self.assertEqual( - got, want, "Rows of type list must be returned as tuples" - ) + self.assertEqual(got, want, "Rows of type list must be returned as tuples") seventeen = PeekIterator([[17]]) self.assertEqual(list(seventeen), [(17,)]) @@ -59,10 +54,7 @@ def test_peekIterator_list_rows_converted_to_tuples(self): pit = PeekIterator([("Clark", "Kent")]) self.assertEqual(next(pit), ("Clark", "Kent")) - @unittest.skipIf( - skip_condition, - 'Python 2 has an outdated iterator definition' - ) + @unittest.skipIf(skip_condition, "Python 2 has an outdated iterator definition") def test_peekIterator_nonlist_rows_unconverted(self): from google.cloud.spanner_dbapi.utils import PeekIterator From 4a49b1f5ef8469f092ebaa7e0b70367feb508520 Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Thu, 5 Nov 2020 17:40:05 -0500 Subject: [PATCH 11/12] chore: license headers updated --- google/cloud/spanner_dbapi/__init__.py | 16 ++++++++++++---- google/cloud/spanner_dbapi/_helpers.py | 16 ++++++++++++---- google/cloud/spanner_dbapi/connection.py | 16 ++++++++++++---- google/cloud/spanner_dbapi/cursor.py | 16 ++++++++++++---- google/cloud/spanner_dbapi/exceptions.py | 16 ++++++++++++---- google/cloud/spanner_dbapi/parse_utils.py | 16 ++++++++++++---- google/cloud/spanner_dbapi/parser.py | 16 ++++++++++++---- google/cloud/spanner_dbapi/types.py | 16 ++++++++++++---- google/cloud/spanner_dbapi/utils.py | 17 ++++++++++++----- google/cloud/spanner_dbapi/version.py | 16 ++++++++++++---- tests/unit/spanner_dbapi/__init__.py | 16 ++++++++++++---- tests/unit/spanner_dbapi/test__helpers.py | 16 ++++++++++++---- tests/unit/spanner_dbapi/test_connection.py | 16 ++++++++++++---- tests/unit/spanner_dbapi/test_cursor.py | 16 ++++++++++++---- tests/unit/spanner_dbapi/test_globals.py | 16 ++++++++++++---- tests/unit/spanner_dbapi/test_parse_utils.py | 16 ++++++++++++---- tests/unit/spanner_dbapi/test_parser.py | 16 ++++++++++++---- tests/unit/spanner_dbapi/test_types.py | 16 ++++++++++++---- tests/unit/spanner_dbapi/test_utils.py | 17 ++++++++++++----- 19 files changed, 228 insertions(+), 78 deletions(-) diff --git a/google/cloud/spanner_dbapi/__init__.py b/google/cloud/spanner_dbapi/__init__.py index 7695c0058f..e94ecdc0ed 100644 --- a/google/cloud/spanner_dbapi/__init__.py +++ b/google/cloud/spanner_dbapi/__init__.py @@ -1,8 +1,16 @@ -# Copyright 2020 Google LLC +# Copyright 2020 Google LLC All rights reserved. # -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Connection-based DB API for Cloud Spanner.""" diff --git a/google/cloud/spanner_dbapi/_helpers.py b/google/cloud/spanner_dbapi/_helpers.py index e8b981c4d0..b7b965fcd7 100644 --- a/google/cloud/spanner_dbapi/_helpers.py +++ b/google/cloud/spanner_dbapi/_helpers.py @@ -1,8 +1,16 @@ -# Copyright 2020 Google LLC +# Copyright 2020 Google LLC All rights reserved. # -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from google.cloud.spanner_dbapi.parse_utils import get_param_types from google.cloud.spanner_dbapi.parse_utils import parse_insert diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 45b69bc067..befc760ea5 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -1,8 +1,16 @@ -# Copyright 2020 Google LLC +# Copyright 2020 Google LLC All rights reserved. # -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """DB-API Connection for the Google Cloud Spanner.""" diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 96433c1d0c..ceaccccdf3 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -1,8 +1,16 @@ -# Copyright 2020 Google LLC +# Copyright 2020 Google LLC All rights reserved. # -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Database cursor for Google Cloud Spanner DB-API.""" diff --git a/google/cloud/spanner_dbapi/exceptions.py b/google/cloud/spanner_dbapi/exceptions.py index b21be2c949..1a9fdd3625 100644 --- a/google/cloud/spanner_dbapi/exceptions.py +++ b/google/cloud/spanner_dbapi/exceptions.py @@ -1,8 +1,16 @@ -# Copyright 2020 Google LLC +# Copyright 2020 Google LLC All rights reserved. # -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Spanner DB API exceptions.""" diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py index 8201f9a19c..aeed4feb2c 100644 --- a/google/cloud/spanner_dbapi/parse_utils.py +++ b/google/cloud/spanner_dbapi/parse_utils.py @@ -1,8 +1,16 @@ -# Copyright 2020 Google LLC +# Copyright 2020 Google LLC All rights reserved. # -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. "SQL parsing and classification utils." diff --git a/google/cloud/spanner_dbapi/parser.py b/google/cloud/spanner_dbapi/parser.py index 074d733c72..9271631b25 100644 --- a/google/cloud/spanner_dbapi/parser.py +++ b/google/cloud/spanner_dbapi/parser.py @@ -1,8 +1,16 @@ -# Copyright 2020 Google LLC +# Copyright 2020 Google LLC All rights reserved. # -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ Grammar for parsing VALUES: diff --git a/google/cloud/spanner_dbapi/types.py b/google/cloud/spanner_dbapi/types.py index 8c6bd27577..80d7030402 100644 --- a/google/cloud/spanner_dbapi/types.py +++ b/google/cloud/spanner_dbapi/types.py @@ -1,8 +1,16 @@ -# Copyright 2020 Google LLC +# Copyright 2020 Google LLC All rights reserved. # -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Implementation of the type objects and constructors according to the PEP-0249 specification. diff --git a/google/cloud/spanner_dbapi/utils.py b/google/cloud/spanner_dbapi/utils.py index 97a33fc0cc..b0ad3922a5 100644 --- a/google/cloud/spanner_dbapi/utils.py +++ b/google/cloud/spanner_dbapi/utils.py @@ -1,9 +1,16 @@ -# coding=utf-8 -# Copyright 2020 Google LLC +# Copyright 2020 Google LLC All rights reserved. # -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import re diff --git a/google/cloud/spanner_dbapi/version.py b/google/cloud/spanner_dbapi/version.py index 88d8f7cdaf..b0e48cff0b 100644 --- a/google/cloud/spanner_dbapi/version.py +++ b/google/cloud/spanner_dbapi/version.py @@ -1,8 +1,16 @@ -# Copyright 2020 Google LLC +# Copyright 2020 Google LLC All rights reserved. # -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import platform diff --git a/tests/unit/spanner_dbapi/__init__.py b/tests/unit/spanner_dbapi/__init__.py index 6b607710ed..377df12f71 100644 --- a/tests/unit/spanner_dbapi/__init__.py +++ b/tests/unit/spanner_dbapi/__init__.py @@ -1,5 +1,13 @@ -# Copyright 2020 Google LLC +# Copyright 2020 Google LLC All rights reserved. # -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/spanner_dbapi/test__helpers.py b/tests/unit/spanner_dbapi/test__helpers.py index c52c617543..84c1bae3b1 100644 --- a/tests/unit/spanner_dbapi/test__helpers.py +++ b/tests/unit/spanner_dbapi/test__helpers.py @@ -1,8 +1,16 @@ -# Copyright 2020 Google LLC +# Copyright 2020 Google LLC All rights reserved. # -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Cloud Spanner DB-API Connection class unit tests.""" diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index bd9dd80c8c..8cd3bced16 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -1,8 +1,16 @@ -# Copyright 2020 Google LLC +# Copyright 2020 Google LLC All rights reserved. # -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Cloud Spanner DB-API Connection class unit tests.""" diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 78ac98bc9f..23ed5010d1 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -1,8 +1,16 @@ -# Copyright 2020 Google LLC +# Copyright 2020 Google LLC All rights reserved. # -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Cursor() class unit tests.""" diff --git a/tests/unit/spanner_dbapi/test_globals.py b/tests/unit/spanner_dbapi/test_globals.py index 3f8360e2ea..2960862ec3 100644 --- a/tests/unit/spanner_dbapi/test_globals.py +++ b/tests/unit/spanner_dbapi/test_globals.py @@ -1,8 +1,16 @@ -# Copyright 2020 Google LLC +# Copyright 2020 Google LLC All rights reserved. # -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import unittest diff --git a/tests/unit/spanner_dbapi/test_parse_utils.py b/tests/unit/spanner_dbapi/test_parse_utils.py index 4417e7e0c0..d411f2425b 100644 --- a/tests/unit/spanner_dbapi/test_parse_utils.py +++ b/tests/unit/spanner_dbapi/test_parse_utils.py @@ -1,8 +1,16 @@ -# Copyright 2020 Google LLC +# Copyright 2020 Google LLC All rights reserved. # -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import sys import unittest diff --git a/tests/unit/spanner_dbapi/test_parser.py b/tests/unit/spanner_dbapi/test_parser.py index b203328024..2343800489 100644 --- a/tests/unit/spanner_dbapi/test_parser.py +++ b/tests/unit/spanner_dbapi/test_parser.py @@ -1,8 +1,16 @@ -# Copyright 2020 Google LLC +# Copyright 2020 Google LLC All rights reserved. # -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import mock import sys diff --git a/tests/unit/spanner_dbapi/test_types.py b/tests/unit/spanner_dbapi/test_types.py index 4246a43e45..8c9dbe6c2b 100644 --- a/tests/unit/spanner_dbapi/test_types.py +++ b/tests/unit/spanner_dbapi/test_types.py @@ -1,8 +1,16 @@ -# Copyright 2020 Google LLC +# Copyright 2020 Google LLC All rights reserved. # -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import unittest diff --git a/tests/unit/spanner_dbapi/test_utils.py b/tests/unit/spanner_dbapi/test_utils.py index 64a6130aa0..4fe94f30a7 100644 --- a/tests/unit/spanner_dbapi/test_utils.py +++ b/tests/unit/spanner_dbapi/test_utils.py @@ -1,9 +1,16 @@ -# coding=utf-8 -# Copyright 2020 Google LLC +# Copyright 2020 Google LLC All rights reserved. # -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import sys import unittest From e3d12cad8218a3e5ff2be9405131368e693b8fc3 Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Tue, 10 Nov 2020 00:04:54 -0500 Subject: [PATCH 12/12] chore: minor fixes --- google/cloud/spanner_dbapi/_helpers.py | 5 +---- google/cloud/spanner_dbapi/parse_utils.py | 2 +- tests/unit/spanner_dbapi/test__helpers.py | 2 +- tests/unit/spanner_dbapi/test_parse_utils.py | 7 ++++--- 4 files changed, 7 insertions(+), 9 deletions(-) diff --git a/google/cloud/spanner_dbapi/_helpers.py b/google/cloud/spanner_dbapi/_helpers.py index b7b965fcd7..2fcdd59137 100644 --- a/google/cloud/spanner_dbapi/_helpers.py +++ b/google/cloud/spanner_dbapi/_helpers.py @@ -59,10 +59,7 @@ def _execute_insert_heterogenous(transaction, sql_params_list): for sql, params in sql_params_list: sql, params = sql_pyformat_args_to_spanner(sql, params) param_types = get_param_types(params) - res = transaction.execute_sql(sql, params=params, param_types=param_types) - # TODO: File a bug with Cloud Spanner and the Python client maintainers - # about a lost commit when res isn't read from. - _ = list(res) + transaction.execute_update(sql, params=params, param_types=param_types) def _execute_insert_homogenous(transaction, parts): diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py index aeed4feb2c..d88dcafb0d 100644 --- a/google/cloud/spanner_dbapi/parse_utils.py +++ b/google/cloud/spanner_dbapi/parse_utils.py @@ -494,7 +494,7 @@ def cast_for_spanner(value): :returns: Value converted to a Cloud Spanner type. """ if isinstance(value, decimal.Decimal): - return float(value) + return str(value) return value diff --git a/tests/unit/spanner_dbapi/test__helpers.py b/tests/unit/spanner_dbapi/test__helpers.py index 84c1bae3b1..84d6b3e323 100644 --- a/tests/unit/spanner_dbapi/test__helpers.py +++ b/tests/unit/spanner_dbapi/test__helpers.py @@ -32,7 +32,7 @@ def test__execute_insert_heterogenous(self): "google.cloud.spanner_dbapi._helpers.get_param_types", return_value=None ) as mock_param_types: transaction = mock.MagicMock() - transaction.execute_sql = mock_execute = mock.MagicMock() + transaction.execute_update = mock_execute = mock.MagicMock() _helpers._execute_insert_heterogenous(transaction, [params]) mock_pyformat.assert_called_once_with(params[0], params[1]) diff --git a/tests/unit/spanner_dbapi/test_parse_utils.py b/tests/unit/spanner_dbapi/test_parse_utils.py index d411f2425b..a79ad8dc51 100644 --- a/tests/unit/spanner_dbapi/test_parse_utils.py +++ b/tests/unit/spanner_dbapi/test_parse_utils.py @@ -307,7 +307,7 @@ def test_sql_pyformat_args_to_spanner(self): ), ( "SELECT (an.p + @a0) AS np FROM an WHERE (an.p + @a1) = @a2", - {"a0": 1, "a1": 1.0, "a2": 31.0}, + {"a0": 1, "a1": 1.0, "a2": str(31)}, ), ), ] @@ -344,8 +344,9 @@ def test_cast_for_spanner(self): from google.cloud.spanner_dbapi.parse_utils import cast_for_spanner - value = decimal.Decimal(3) - self.assertEqual(cast_for_spanner(value), float(3.0)) + dec = 3 + value = decimal.Decimal(dec) + self.assertEqual(cast_for_spanner(value), str(dec)) self.assertEqual(cast_for_spanner(5), 5) self.assertEqual(cast_for_spanner("string"), "string")