From 3d082e408a06c88b08a14cac37ce27fdab777088 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Fri, 21 Aug 2020 12:02:59 +0300 Subject: [PATCH 01/12] feat: refactor connect() function, cover it with unit tests --- spanner_dbapi/__init__.py | 149 ++++++++++++++++++---------- tests/spanner_dbapi/test_connect.py | 116 ++++++++++++++++++++++ 2 files changed, 210 insertions(+), 55 deletions(-) create mode 100644 tests/spanner_dbapi/test_connect.py diff --git a/spanner_dbapi/__init__.py b/spanner_dbapi/__init__.py index f5d349a655..cf88da598f 100644 --- a/spanner_dbapi/__init__.py +++ b/spanner_dbapi/__init__.py @@ -4,83 +4,122 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -from google.cloud import spanner_v1 as spanner +"""Connection-based DB API for Cloud Spanner.""" + +from google.cloud import spanner_v1 from .connection import Connection -# These need to be included in the top-level package for PEP-0249 DB API v2. from .exceptions import ( - DatabaseError, DataError, Error, IntegrityError, InterfaceError, - InternalError, NotSupportedError, OperationalError, ProgrammingError, + DatabaseError, + DataError, + Error, + IntegrityError, + InterfaceError, + InternalError, + NotSupportedError, + OperationalError, + ProgrammingError, Warning, ) from .parse_utils import get_param_types from .types import ( - BINARY, DATETIME, NUMBER, ROWID, STRING, Binary, Date, DateFromTicks, Time, - TimeFromTicks, Timestamp, TimestampFromTicks, + BINARY, + DATETIME, + NUMBER, + ROWID, + STRING, + Binary, + Date, + DateFromTicks, + Time, + TimeFromTicks, + Timestamp, + TimestampFromTicks, ) from .version import google_client_info -# Globals that MUST be defined ### -apilevel = "2.0" # Implements the Python Database API specification 2.0 version. -# We accept arguments in the format '%s' aka ANSI C print codes. -# as per https://www.python.org/dev/peps/pep-0249/#paramstyle -paramstyle = 'format' -# Threads may share the module but not connections. This is a paranoid threadsafety level, -# but it is necessary for starters to use when debugging failures. Eventually once transactions -# are working properly, we'll update the threadsafety level. +apilevel = "2.0" # supports DP-API 2.0 level. +paramstyle = "format" # ANSI C printf format codes, e.g. ...WHERE name=%s. + +# Threads may share the module, but not connections. This is a paranoid threadsafety +# level, but it is necessary for starters to use when debugging failures. +# Eventually once transactions are working properly, we'll update the +# threadsafety level. threadsafety = 1 -def connect(project=None, instance=None, database=None, credentials_uri=None, user_agent=None): +def connect(instance_id, database_id, project=None, credentials=None, user_agent=None): """ - Connect to Cloud Spanner. + Create a connection to Cloud Spanner database. - Args: - project: The id of a project that already exists. - instance: The id of an instance that already exists. - database: The name of a database that already exists. - credentials_uri: An optional string specifying where to retrieve the service - account JSON for the credentials to connect to Cloud Spanner. + :type instance_id: :class:`str` + :param instance_id: ID of the instance to connect to. - Returns: - The Connection object associated to the Cloud Spanner instance. + :type database_id: :class:`str` + :param database_id: The name of the database to connect to. - Raises: - Error if it encounters any unexpected inputs. - """ - if not project: - raise Error("'project' is required.") - if not instance: - raise Error("'instance' is required.") - if not database: - raise Error("'database' is required.") + :type project: :class:`str` + :param project: (Optional) The ID of the project which owns the + instances, tables and data. If not provided, will + attempt to determine from the environment. - client_kwargs = { - 'project': project, - 'client_info': google_client_info(user_agent), - } - if credentials_uri: - client = spanner.Client.from_service_account_json(credentials_uri, **client_kwargs) - else: - client = spanner.Client(**client_kwargs) + :type credentials: :class:`google.auth.credentials.Credentials` + :param credentials: (Optional) The authorization credentials to attach to requests. + These credentials identify this application to the service. + If none are specified, the client will attempt to ascertain + the credentials from the environment. + + :rtype: :class:`google.cloud.spanner_dbapi.connection.Connection` + :returns: Connection object associated with the given Cloud Spanner resource. + + :raises: :class:`ProgrammingError` in case of given instance/database + doesn't exist. + """ + client = spanner_v1.Client( + project=project, + credentials=credentials, + client_info=google_client_info(user_agent), + ) - client_instance = client.instance(instance) - if not client_instance.exists(): - raise ProgrammingError("instance '%s' does not exist." % instance) + instance = client.instance(instance_id) + if not instance.exists(): + raise ProgrammingError("instance '%s' does not exist." % instance_id) - db = client_instance.database(database, pool=spanner.pool.BurstyPool()) - if not db.exists(): - raise ProgrammingError("database '%s' does not exist." % database) + database = instance.database(database_id, pool=spanner_v1.pool.BurstyPool()) + if not database.exists(): + raise ProgrammingError("database '%s' does not exist." % database_id) - return Connection(db) + return Connection(database) __all__ = [ - 'DatabaseError', 'DataError', 'Error', 'IntegrityError', 'InterfaceError', - 'InternalError', 'NotSupportedError', 'OperationalError', 'ProgrammingError', - 'Warning', 'DEFAULT_USER_AGENT', 'apilevel', 'connect', 'paramstyle', 'threadsafety', - 'get_param_types', - 'Binary', 'Date', 'DateFromTicks', 'Time', 'TimeFromTicks', 'Timestamp', - 'TimestampFromTicks', - 'BINARY', 'STRING', 'NUMBER', 'DATETIME', 'ROWID', 'TimestampStr', + "DatabaseError", + "DataError", + "Error", + "IntegrityError", + "InterfaceError", + "InternalError", + "NotSupportedError", + "OperationalError", + "ProgrammingError", + "Warning", + "DEFAULT_USER_AGENT", + "apilevel", + "connect", + "paramstyle", + "threadsafety", + "get_param_types", + "Binary", + "Date", + "DateFromTicks", + "Time", + "TimeFromTicks", + "Timestamp", + "TimestampFromTicks", + "BINARY", + "STRING", + "NUMBER", + "DATETIME", + "ROWID", + "TimestampStr", ] diff --git a/tests/spanner_dbapi/test_connect.py b/tests/spanner_dbapi/test_connect.py new file mode 100644 index 0000000000..faa3c84fe1 --- /dev/null +++ b/tests/spanner_dbapi/test_connect.py @@ -0,0 +1,116 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +"""connect() module function unit tests.""" + +import mock +import unittest + + +def _make_credentials(): + import google.auth.credentials + + class _CredentialsWithScopes( + google.auth.credentials.Credentials, google.auth.credentials.Scoped + ): + pass + + return mock.Mock(spec=_CredentialsWithScopes) + + +class Testconnect(unittest.TestCase): + def _callFUT(self, *args, **kw): + from google.cloud.spanner_dbapi import connect + + return connect(*args, **kw) + + def test_connect(self): + from google.api_core.gapic_v1.client_info import ClientInfo + from google.cloud.spanner_dbapi.connection import Connection + + PROJECT = "test-project" + USER_AGENT = "user-agent" + CREDENTIALS = _make_credentials() + CLIENT_INFO = ClientInfo(user_agent=USER_AGENT) + + with mock.patch("google.cloud.spanner_dbapi.spanner_v1.Client") as client_mock: + with mock.patch( + "google.cloud.spanner_dbapi.google_client_info", + return_value=CLIENT_INFO, + ) as client_info_mock: + + connection = self._callFUT( + "test-instance", "test-database", PROJECT, CREDENTIALS, USER_AGENT + ) + + self.assertIsInstance(connection, Connection) + client_info_mock.assert_called_once_with(USER_AGENT) + + client_mock.assert_called_once_with( + project=PROJECT, credentials=CREDENTIALS, client_info=CLIENT_INFO + ) + + def test_instance_not_found(self): + from google.cloud.spanner_dbapi.exceptions import ProgrammingError + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=False + ) as exists_mock: + + with self.assertRaises(ProgrammingError): + self._callFUT("test-instance", "test-database") + + exists_mock.assert_called_once() + + def test_database_not_found(self): + from google.cloud.spanner_dbapi.exceptions import ProgrammingError + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=False + ) as exists_mock: + + with self.assertRaises(ProgrammingError): + self._callFUT("test-instance", "test-database") + + exists_mock.assert_called_once() + + def test_connect_instance_id(self): + from google.cloud.spanner_dbapi.connection import Connection + + INSTANCE = "test-instance" + + with mock.patch( + "google.cloud.spanner_v1.client.Client.instance" + ) as instance_mock: + connection = self._callFUT(INSTANCE, "test-database") + + instance_mock.assert_called_once_with(INSTANCE) + + self.assertIsInstance(connection, Connection) + + def test_connect_database_id(self): + from google.cloud.spanner_dbapi.connection import Connection + + DATABASE = "test-database" + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.database" + ) as database_mock: + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True + ): + connection = self._callFUT("test-instance", DATABASE) + + database_mock.assert_called_once_with(DATABASE, pool=mock.ANY) + + self.assertIsInstance(connection, Connection) + + +if __name__ == "__main__": + unittest.main() From 13d672b156ab7d11bb213c23815931e0e359e5cb Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Fri, 21 Aug 2020 12:19:22 +0300 Subject: [PATCH 02/12] fix mock import --- tests/spanner_dbapi/test_connect.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/spanner_dbapi/test_connect.py b/tests/spanner_dbapi/test_connect.py index faa3c84fe1..5d7f193b58 100644 --- a/tests/spanner_dbapi/test_connect.py +++ b/tests/spanner_dbapi/test_connect.py @@ -6,8 +6,8 @@ """connect() module function unit tests.""" -import mock import unittest +from unittest import mock def _make_credentials(): From 6eaa1f0848573e793e8b2385deab02ec493f04ec Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Fri, 21 Aug 2020 12:27:00 +0300 Subject: [PATCH 03/12] change imports to the db_api package instead of google.cloud --- tests/spanner_dbapi/test_connect.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/spanner_dbapi/test_connect.py b/tests/spanner_dbapi/test_connect.py index 5d7f193b58..a57d3ca0ff 100644 --- a/tests/spanner_dbapi/test_connect.py +++ b/tests/spanner_dbapi/test_connect.py @@ -23,23 +23,22 @@ class _CredentialsWithScopes( class Testconnect(unittest.TestCase): def _callFUT(self, *args, **kw): - from google.cloud.spanner_dbapi import connect + from spanner_dbapi import connect return connect(*args, **kw) def test_connect(self): from google.api_core.gapic_v1.client_info import ClientInfo - from google.cloud.spanner_dbapi.connection import Connection + from spanner_dbapi.connection import Connection PROJECT = "test-project" USER_AGENT = "user-agent" CREDENTIALS = _make_credentials() CLIENT_INFO = ClientInfo(user_agent=USER_AGENT) - with mock.patch("google.cloud.spanner_dbapi.spanner_v1.Client") as client_mock: + with mock.patch("spanner_dbapi.spanner_v1.Client") as client_mock: with mock.patch( - "google.cloud.spanner_dbapi.google_client_info", - return_value=CLIENT_INFO, + "spanner_dbapi.google_client_info", return_value=CLIENT_INFO ) as client_info_mock: connection = self._callFUT( @@ -54,7 +53,7 @@ def test_connect(self): ) def test_instance_not_found(self): - from google.cloud.spanner_dbapi.exceptions import ProgrammingError + from spanner_dbapi.exceptions import ProgrammingError with mock.patch( "google.cloud.spanner_v1.instance.Instance.exists", return_value=False @@ -66,7 +65,7 @@ def test_instance_not_found(self): exists_mock.assert_called_once() def test_database_not_found(self): - from google.cloud.spanner_dbapi.exceptions import ProgrammingError + from spanner_dbapi.exceptions import ProgrammingError with mock.patch( "google.cloud.spanner_v1.instance.Instance.exists", return_value=True @@ -81,7 +80,7 @@ def test_database_not_found(self): exists_mock.assert_called_once() def test_connect_instance_id(self): - from google.cloud.spanner_dbapi.connection import Connection + from spanner_dbapi.connection import Connection INSTANCE = "test-instance" @@ -95,7 +94,7 @@ def test_connect_instance_id(self): self.assertIsInstance(connection, Connection) def test_connect_database_id(self): - from google.cloud.spanner_dbapi.connection import Connection + from spanner_dbapi.connection import Connection DATABASE = "test-database" From 7607d9e2aae297a0c8c4723ecc8856382de194f3 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Mon, 24 Aug 2020 12:47:57 +0300 Subject: [PATCH 04/12] feat: cursor must detect if the parent connection is closed --- spanner_dbapi/connection.py | 54 ++++++----- spanner_dbapi/cursor.py | 138 ++++++++++++++++++----------- tests/spanner_dbapi/test_cursor.py | 54 +++++++++++ 3 files changed, 172 insertions(+), 74 deletions(-) create mode 100644 tests/spanner_dbapi/test_cursor.py diff --git a/spanner_dbapi/connection.py b/spanner_dbapi/connection.py index 20b707adb0..0ae8c84d27 100644 --- a/spanner_dbapi/connection.py +++ b/spanner_dbapi/connection.py @@ -11,26 +11,31 @@ from .cursor import Cursor from .exceptions import InterfaceError -ColumnDetails = namedtuple('column_details', ['null_ok', 'spanner_type']) +ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"]) class Connection: def __init__(self, db_handle): self._dbhandle = db_handle - self._closed = False self._ddl_statements = [] + self.is_closed = False + def cursor(self): - self.__raise_if_already_closed() + self._raise_if_already_closed() return Cursor(self) - def __raise_if_already_closed(self): - """ - Raise an exception if attempting to use an already closed connection. + def _raise_if_already_closed(self): + """Raise an exception if this connection is closed. + + Helper to check the connection state before + running a SQL/DDL/DML query. + + :raises: :class:`InterfaceError` if this connection is closed. """ - if self._closed: - raise InterfaceError('connection already closed') + if self.is_closed: + raise InterfaceError("connection is already closed") def __handle_update_ddl(self, ddl_statements): """ @@ -41,24 +46,24 @@ def __handle_update_ddl(self, ddl_statements): Returns: google.api_core.operation.Operation.result() """ - self.__raise_if_already_closed() + self._raise_if_already_closed() # Synchronously wait on the operation's completion. return self._dbhandle.update_ddl(ddl_statements).result() def read_snapshot(self): - self.__raise_if_already_closed() + self._raise_if_already_closed() return self._dbhandle.snapshot() def in_transaction(self, fn, *args, **kwargs): - self.__raise_if_already_closed() + self._raise_if_already_closed() return self._dbhandle.run_in_transaction(fn, *args, **kwargs) def append_ddl_statement(self, ddl_statement): - self.__raise_if_already_closed() + self._raise_if_already_closed() self._ddl_statements.append(ddl_statement) def run_prior_DDL_statements(self): - self.__raise_if_already_closed() + self._raise_if_already_closed() if not self._ddl_statements: return @@ -69,14 +74,16 @@ def run_prior_DDL_statements(self): return self.__handle_update_ddl(ddl_statements) def list_tables(self): - return self.run_sql_in_snapshot(""" + return self.run_sql_in_snapshot( + """ SELECT t.table_name FROM information_schema.tables AS t WHERE t.table_catalog = '' and t.table_schema = '' - """) + """ + ) def run_sql_in_snapshot(self, sql, params=None, param_types=None): # Some SQL e.g. for INFORMATION_SCHEMA cannot be run in read-write transactions @@ -89,38 +96,37 @@ def run_sql_in_snapshot(self, sql, params=None, param_types=None): def get_table_column_schema(self, table_name): rows = self.run_sql_in_snapshot( - '''SELECT + """SELECT COLUMN_NAME, IS_NULLABLE, SPANNER_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = '' AND - TABLE_NAME = @table_name''', - params={'table_name': table_name}, - param_types={'table_name': spanner.param_types.STRING}, + TABLE_NAME = @table_name""", + params={"table_name": table_name}, + param_types={"table_name": spanner.param_types.STRING}, ) column_details = {} for column_name, is_nullable, spanner_type in rows: column_details[column_name] = ColumnDetails( - null_ok=is_nullable == 'YES', - spanner_type=spanner_type, + null_ok=is_nullable == "YES", spanner_type=spanner_type ) return column_details def close(self): self.rollback() self.__dbhandle = None - self._closed = True + self.is_closed = True def commit(self): - self.__raise_if_already_closed() + self._raise_if_already_closed() self.run_prior_DDL_statements() def rollback(self): - self.__raise_if_already_closed() + self._raise_if_already_closed() # TODO: to be added. diff --git a/spanner_dbapi/cursor.py b/spanner_dbapi/cursor.py index d5f08c4e93..ebf0cb66f0 100644 --- a/spanner_dbapi/cursor.py +++ b/spanner_dbapi/cursor.py @@ -8,11 +8,19 @@ from google.cloud.spanner_v1 import param_types from .exceptions import ( - IntegrityError, InterfaceError, OperationalError, ProgrammingError, + IntegrityError, + InterfaceError, + OperationalError, + ProgrammingError, ) from .parse_utils import ( - STMT_DDL, STMT_INSERT, STMT_NON_UPDATING, classify_stmt, - ensure_where_clause, get_param_types, parse_insert, + STMT_DDL, + STMT_INSERT, + STMT_NON_UPDATING, + classify_stmt, + ensure_where_clause, + get_param_types, + parse_insert, sql_pyformat_args_to_spanner, ) from .utils import PeekIterator @@ -44,12 +52,9 @@ def __init__(self, connection): self._res = None self._row_count = _UNSET_COUNT self._connection = connection - self._closed = False + self._is_closed = False - # arraysize is a readable and writable property mandated - # by PEP-0249 https://www.python.org/dev/peps/pep-0249/#arraysize - # It determines the results of .fetchmany - self.arraysize = 1 + self.arraysize = 1 # the number of rows to fetch at a time with fetchmany() def execute(self, sql, args=None): """ @@ -64,7 +69,7 @@ def execute(self, sql, args=None): self._raise_if_already_closed() if not self._connection: - raise ProgrammingError('Cursor is not connected to the database') + raise ProgrammingError("Cursor is not connected to the database") self._res = None @@ -86,23 +91,22 @@ def execute(self, sql, args=None): else: self.__handle_update(sql, args or None) except (grpc_exceptions.AlreadyExists, grpc_exceptions.FailedPrecondition) as e: - raise IntegrityError(e.details if hasattr(e, 'details') else e) + raise IntegrityError(e.details if hasattr(e, "details") else e) except grpc_exceptions.InvalidArgument as e: - raise ProgrammingError(e.details if hasattr(e, 'details') else e) + raise ProgrammingError(e.details if hasattr(e, "details") else e) except grpc_exceptions.InternalServerError as e: - raise OperationalError(e.details if hasattr(e, 'details') else e) + raise OperationalError(e.details if hasattr(e, "details") else e) def __handle_update(self, sql, params): - self._connection.in_transaction( - self.__do_execute_update, - sql, params, - ) + self._connection.in_transaction(self.__do_execute_update, sql, params) def __do_execute_update(self, transaction, sql, params, param_types=None): sql = ensure_where_clause(sql) sql, params = sql_pyformat_args_to_spanner(sql, params) - res = transaction.execute_update(sql, params=params, param_types=get_param_types(params)) + res = transaction.execute_update( + sql, params=params, param_types=get_param_types(params) + ) self._itr = None if type(res) == int: self._row_count = res @@ -125,20 +129,18 @@ def __handle_insert(self, sql, params): # transaction.execute_sql(sql, params, param_types) # which invokes more RPCs and is more costly. - if parts.get('homogenous'): + if parts.get("homogenous"): # The common case of multiple values being passed in # non-complex pyformat args and need to be uploaded in one RPC. return self._connection.in_transaction( - self.__do_execute_insert_homogenous, - parts, + self.__do_execute_insert_homogenous, parts ) else: # All the other cases that are esoteric and need # transaction.execute_sql - sql_params_list = parts.get('sql_params_list') + sql_params_list = parts.get("sql_params_list") return self._connection.in_transaction( - self.__do_execute_insert_heterogenous, - sql_params_list, + self.__do_execute_insert_heterogenous, sql_params_list ) def __do_execute_insert_heterogenous(self, transaction, sql_params_list): @@ -152,9 +154,9 @@ def __do_execute_insert_heterogenous(self, transaction, sql_params_list): def __do_execute_insert_homogenous(self, transaction, parts): # Perform an insert in one shot. - table = parts.get('table') - columns = parts.get('columns') - values = parts.get('values') + table = parts.get("table") + columns = parts.get("columns") + values = parts.get("values") return transaction.insert(table, columns, values) def __handle_DQL(self, sql, params): @@ -162,7 +164,9 @@ def __handle_DQL(self, sql, params): # Reference # https://googleapis.dev/python/spanner/latest/session-api.html#google.cloud.spanner_v1.session.Session.execute_sql sql, params = sql_pyformat_args_to_spanner(sql, params) - res = snapshot.execute_sql(sql, params=params, param_types=get_param_types(params)) + res = snapshot.execute_sql( + sql, params=params, param_types=get_param_types(params) + ) if type(res) == int: self._row_count = res self._itr = None @@ -216,32 +220,48 @@ def description(self): def rowcount(self): return self._row_count - def _raise_if_already_closed(self): + @property + def is_closed(self): + """The cursor close indicator. + + Returns: + bool: + True if this cursor or it's parent connection + is closed, False otherwise. """ - Raise an exception if attempting to use an already closed connection. + return self._is_closed or self._connection.is_closed + + def _raise_if_already_closed(self): + """Raise an exception if this cursor is closed. + + Helper to check this cursor's state before running a + SQL/DDL/DML query. If the parent connection is + already closed it also raises an error. + + :raises: :class:`InterfaceError` if this cursor is closed. """ - if self._closed: - raise InterfaceError('cursor already closed') + if self.is_closed: + raise InterfaceError("cursor is already closed") def close(self): self.__clear() - self._closed = True + self._is_closed = True def executemany(self, operation, seq_of_params): if not self._connection: - raise ProgrammingError('Cursor is not connected to the database') + raise ProgrammingError("Cursor is not connected to the database") for params in seq_of_params: self.execute(operation, params) def __next__(self): if self._itr is None: - raise ProgrammingError('no results to return') + raise ProgrammingError("no results to return") return next(self._itr) def __iter__(self): if self._itr is None: - raise ProgrammingError('no results to return') + raise ProgrammingError("no results to return") return self._itr def fetchone(self): @@ -289,10 +309,10 @@ def lastrowid(self): return None def setinputsizes(sizes): - raise ProgrammingError('Unimplemented') + raise ProgrammingError("Unimplemented") def setoutputsize(size, column=None): - raise ProgrammingError('Unimplemented') + raise ProgrammingError("Unimplemented") def _run_prior_DDL_statements(self): return self._connection.run_prior_DDL_statements() @@ -308,8 +328,16 @@ def get_table_column_schema(self, table_name): class Column: - def __init__(self, name, type_code, display_size=None, internal_size=None, - precision=None, scale=None, null_ok=False): + def __init__( + self, + name, + type_code, + display_size=None, + internal_size=None, + precision=None, + scale=None, + null_ok=False, + ): self.name = name self.type_code = type_code self.display_size = display_size @@ -338,14 +366,24 @@ def __getitem__(self, index): return self.null_ok def __str__(self): - rstr = ', '.join([field for field in [ - "name='%s'" % self.name, - "type_code=%d" % self.type_code, - None if not self.display_size else "display_size=%d" % self.display_size, - None if not self.internal_size else "internal_size=%d" % self.internal_size, - None if not self.precision else "precision='%s'" % self.precision, - None if not self.scale else "scale='%s'" % self.scale, - None if not self.null_ok else "null_ok='%s'" % self.null_ok, - ] if field]) - - return 'Column(%s)' % rstr + rstr = ", ".join( + [ + field + for field in [ + "name='%s'" % self.name, + "type_code=%d" % self.type_code, + None + if not self.display_size + else "display_size=%d" % self.display_size, + None + if not self.internal_size + else "internal_size=%d" % self.internal_size, + None if not self.precision else "precision='%s'" % self.precision, + None if not self.scale else "scale='%s'" % self.scale, + None if not self.null_ok else "null_ok='%s'" % self.null_ok, + ] + if field + ] + ) + + return "Column(%s)" % rstr diff --git a/tests/spanner_dbapi/test_cursor.py b/tests/spanner_dbapi/test_cursor.py new file mode 100644 index 0000000000..99014ea254 --- /dev/null +++ b/tests/spanner_dbapi/test_cursor.py @@ -0,0 +1,54 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +"""Cursor() class unit tests.""" + +import unittest +from unittest import mock + + +class TestCursor(unittest.TestCase): + def test_close(self): + from spanner_dbapi import connect + from spanner_dbapi.exceptions import InterfaceError + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=True + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor.close() + + self.assertTrue(cursor.is_closed) + with self.assertRaises(InterfaceError): + cursor.execute("SELECT * FROM database") + + def test_connection_closed(self): + from spanner_dbapi import connect + from spanner_dbapi.exceptions import InterfaceError + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=True + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + connection.close() + + self.assertTrue(cursor.is_closed) + with self.assertRaises(InterfaceError): + cursor.execute("SELECT * FROM database") + + +if __name__ == "__main__": + unittest.main() From c87d556714302c6a82c617265159a655c4c87e18 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Tue, 25 Aug 2020 13:22:40 +0300 Subject: [PATCH 05/12] update formatters configs --- setup.cfg | 4 +++- tox.ini | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index f9e8dff043..43c26175ef 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,8 +2,10 @@ max-line-length = 119 [isort] +use_parentheses=True combine_as_imports = true default_section = THIRDPARTY include_trailing_comma = true +force_grid_wrap=0 line_length = 79 -multi_line_output = 5 +multi_line_output = 3 \ No newline at end of file diff --git a/tox.ini b/tox.ini index 94d9bbe241..148efd99a9 100644 --- a/tox.ini +++ b/tox.ini @@ -21,4 +21,4 @@ deps = isort commands = flake8 - isort --recursive --check-only --diff + isort --recursive --check-only --diff . From c517882c6c23bb2536b1e3c90121dd29c8ccda5e Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Tue, 25 Aug 2020 13:58:10 +0300 Subject: [PATCH 06/12] error type and nits --- spanner_dbapi/__init__.py | 6 +++--- tests/spanner_dbapi/test_connect.py | 14 +++----------- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/spanner_dbapi/__init__.py b/spanner_dbapi/__init__.py index cf88da598f..58037106ca 100644 --- a/spanner_dbapi/__init__.py +++ b/spanner_dbapi/__init__.py @@ -72,7 +72,7 @@ def connect(instance_id, database_id, project=None, credentials=None, user_agent :rtype: :class:`google.cloud.spanner_dbapi.connection.Connection` :returns: Connection object associated with the given Cloud Spanner resource. - :raises: :class:`ProgrammingError` in case of given instance/database + :raises: :class:`ValueError` in case of given instance/database doesn't exist. """ client = spanner_v1.Client( @@ -83,11 +83,11 @@ def connect(instance_id, database_id, project=None, credentials=None, user_agent instance = client.instance(instance_id) if not instance.exists(): - raise ProgrammingError("instance '%s' does not exist." % instance_id) + raise ValueError("instance '%s' does not exist." % instance_id) database = instance.database(database_id, pool=spanner_v1.pool.BurstyPool()) if not database.exists(): - raise ProgrammingError("database '%s' does not exist." % database_id) + raise ValueError("database '%s' does not exist." % database_id) return Connection(database) diff --git a/tests/spanner_dbapi/test_connect.py b/tests/spanner_dbapi/test_connect.py index a57d3ca0ff..5314db64ac 100644 --- a/tests/spanner_dbapi/test_connect.py +++ b/tests/spanner_dbapi/test_connect.py @@ -21,7 +21,7 @@ class _CredentialsWithScopes( return mock.Mock(spec=_CredentialsWithScopes) -class Testconnect(unittest.TestCase): +class Test_connect(unittest.TestCase): def _callFUT(self, *args, **kw): from spanner_dbapi import connect @@ -53,20 +53,16 @@ def test_connect(self): ) def test_instance_not_found(self): - from spanner_dbapi.exceptions import ProgrammingError - with mock.patch( "google.cloud.spanner_v1.instance.Instance.exists", return_value=False ) as exists_mock: - with self.assertRaises(ProgrammingError): + with self.assertRaises(ValueError): self._callFUT("test-instance", "test-database") exists_mock.assert_called_once() def test_database_not_found(self): - from spanner_dbapi.exceptions import ProgrammingError - with mock.patch( "google.cloud.spanner_v1.instance.Instance.exists", return_value=True ): @@ -74,7 +70,7 @@ def test_database_not_found(self): "google.cloud.spanner_v1.database.Database.exists", return_value=False ) as exists_mock: - with self.assertRaises(ProgrammingError): + with self.assertRaises(ValueError): self._callFUT("test-instance", "test-database") exists_mock.assert_called_once() @@ -109,7 +105,3 @@ def test_connect_database_id(self): database_mock.assert_called_once_with(DATABASE, pool=mock.ANY) self.assertIsInstance(connection, Connection) - - -if __name__ == "__main__": - unittest.main() From 0cb7f279788dabbb6ff199c06a4316cab40b03c6 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Wed, 26 Aug 2020 11:21:53 +0300 Subject: [PATCH 07/12] fix imports and erase FUT helper --- tests/spanner_dbapi/test_connect.py | 28 +++++++++------------------- 1 file changed, 9 insertions(+), 19 deletions(-) diff --git a/tests/spanner_dbapi/test_connect.py b/tests/spanner_dbapi/test_connect.py index 5314db64ac..15a86269c8 100644 --- a/tests/spanner_dbapi/test_connect.py +++ b/tests/spanner_dbapi/test_connect.py @@ -9,10 +9,12 @@ import unittest from unittest import mock +import google.auth.credentials +from google.api_core.gapic_v1.client_info import ClientInfo +from spanner_dbapi import connect, Connection -def _make_credentials(): - import google.auth.credentials +def _make_credentials(): class _CredentialsWithScopes( google.auth.credentials.Credentials, google.auth.credentials.Scoped ): @@ -22,15 +24,7 @@ class _CredentialsWithScopes( class Test_connect(unittest.TestCase): - def _callFUT(self, *args, **kw): - from spanner_dbapi import connect - - return connect(*args, **kw) - def test_connect(self): - from google.api_core.gapic_v1.client_info import ClientInfo - from spanner_dbapi.connection import Connection - PROJECT = "test-project" USER_AGENT = "user-agent" CREDENTIALS = _make_credentials() @@ -41,7 +35,7 @@ def test_connect(self): "spanner_dbapi.google_client_info", return_value=CLIENT_INFO ) as client_info_mock: - connection = self._callFUT( + connection = connect( "test-instance", "test-database", PROJECT, CREDENTIALS, USER_AGENT ) @@ -58,7 +52,7 @@ def test_instance_not_found(self): ) as exists_mock: with self.assertRaises(ValueError): - self._callFUT("test-instance", "test-database") + connect("test-instance", "test-database") exists_mock.assert_called_once() @@ -71,27 +65,23 @@ def test_database_not_found(self): ) as exists_mock: with self.assertRaises(ValueError): - self._callFUT("test-instance", "test-database") + connect("test-instance", "test-database") exists_mock.assert_called_once() def test_connect_instance_id(self): - from spanner_dbapi.connection import Connection - INSTANCE = "test-instance" with mock.patch( "google.cloud.spanner_v1.client.Client.instance" ) as instance_mock: - connection = self._callFUT(INSTANCE, "test-database") + connection = connect(INSTANCE, "test-database") instance_mock.assert_called_once_with(INSTANCE) self.assertIsInstance(connection, Connection) def test_connect_database_id(self): - from spanner_dbapi.connection import Connection - DATABASE = "test-database" with mock.patch( @@ -100,7 +90,7 @@ def test_connect_database_id(self): with mock.patch( "google.cloud.spanner_v1.instance.Instance.exists", return_value=True ): - connection = self._callFUT("test-instance", DATABASE) + connection = connect("test-instance", DATABASE) database_mock.assert_called_once_with(DATABASE, pool=mock.ANY) From f049e02fafb5c5f87734e1cc0283c21497657fc3 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Wed, 26 Aug 2020 11:41:32 +0300 Subject: [PATCH 08/12] fix assert_called_once AttributeError in Python 3.5 --- tests/spanner_dbapi/test_connect.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/spanner_dbapi/test_connect.py b/tests/spanner_dbapi/test_connect.py index 15a86269c8..64eaf8cd7d 100644 --- a/tests/spanner_dbapi/test_connect.py +++ b/tests/spanner_dbapi/test_connect.py @@ -54,7 +54,7 @@ def test_instance_not_found(self): with self.assertRaises(ValueError): connect("test-instance", "test-database") - exists_mock.assert_called_once() + exists_mock.assert_called_once_with() def test_database_not_found(self): with mock.patch( @@ -67,7 +67,7 @@ def test_database_not_found(self): with self.assertRaises(ValueError): connect("test-instance", "test-database") - exists_mock.assert_called_once() + exists_mock.assert_called_once_with() def test_connect_instance_id(self): INSTANCE = "test-instance" From dc3891342ec67e5a950392f818e2408ae9ce3c49 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Wed, 26 Aug 2020 13:20:41 +0300 Subject: [PATCH 09/12] fix imports --- tests/spanner_dbapi/test_cursor.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/tests/spanner_dbapi/test_cursor.py b/tests/spanner_dbapi/test_cursor.py index 99014ea254..69d41f4217 100644 --- a/tests/spanner_dbapi/test_cursor.py +++ b/tests/spanner_dbapi/test_cursor.py @@ -9,12 +9,11 @@ import unittest from unittest import mock +from spanner_dbapi import connect, InterfaceError + class TestCursor(unittest.TestCase): def test_close(self): - from spanner_dbapi import connect - from spanner_dbapi.exceptions import InterfaceError - with mock.patch( "google.cloud.spanner_v1.instance.Instance.exists", return_value=True ): @@ -31,9 +30,6 @@ def test_close(self): cursor.execute("SELECT * FROM database") def test_connection_closed(self): - from spanner_dbapi import connect - from spanner_dbapi.exceptions import InterfaceError - with mock.patch( "google.cloud.spanner_v1.instance.Instance.exists", return_value=True ): @@ -48,7 +44,3 @@ def test_connection_closed(self): self.assertTrue(cursor.is_closed) with self.assertRaises(InterfaceError): cursor.execute("SELECT * FROM database") - - -if __name__ == "__main__": - unittest.main() From 51afccd0de20cb8a0f4b425b4c9f2ef0ed095464 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Wed, 26 Aug 2020 19:47:10 +0300 Subject: [PATCH 10/12] fix backends test arg names --- django_spanner/base.py | 116 +++++++++++++++++++++-------------------- 1 file changed, 59 insertions(+), 57 deletions(-) diff --git a/django_spanner/base.py b/django_spanner/base.py index 70f5b1ba39..f5d2c835d5 100644 --- a/django_spanner/base.py +++ b/django_spanner/base.py @@ -18,62 +18,62 @@ class DatabaseWrapper(BaseDatabaseWrapper): - vendor = 'spanner' - display_name = 'Cloud Spanner' + vendor = "spanner" + display_name = "Cloud Spanner" # Mapping of Field objects to their column types. # https://cloud.google.com/spanner/docs/data-types#date-type data_types = { - 'AutoField': 'INT64', - 'BigAutoField': 'INT64', - 'BinaryField': 'BYTES(MAX)', - 'BooleanField': 'BOOL', - 'CharField': 'STRING(%(max_length)s)', - 'DateField': 'DATE', - 'DateTimeField': 'TIMESTAMP', - 'DecimalField': 'FLOAT64', - 'DurationField': 'INT64', - 'EmailField': 'STRING(%(max_length)s)', - 'FileField': 'STRING(%(max_length)s)', - 'FilePathField': 'STRING(%(max_length)s)', - 'FloatField': 'FLOAT64', - 'IntegerField': 'INT64', - 'BigIntegerField': 'INT64', - 'IPAddressField': 'STRING(15)', - 'GenericIPAddressField': 'STRING(39)', - 'NullBooleanField': 'BOOL', - 'OneToOneField': 'INT64', - 'PositiveIntegerField': 'INT64', - 'PositiveSmallIntegerField': 'INT64', - 'SlugField': 'STRING(%(max_length)s)', - 'SmallAutoField': 'INT64', - 'SmallIntegerField': 'INT64', - 'TextField': 'STRING(MAX)', - 'TimeField': 'TIMESTAMP', - 'UUIDField': 'STRING(32)', + "AutoField": "INT64", + "BigAutoField": "INT64", + "BinaryField": "BYTES(MAX)", + "BooleanField": "BOOL", + "CharField": "STRING(%(max_length)s)", + "DateField": "DATE", + "DateTimeField": "TIMESTAMP", + "DecimalField": "FLOAT64", + "DurationField": "INT64", + "EmailField": "STRING(%(max_length)s)", + "FileField": "STRING(%(max_length)s)", + "FilePathField": "STRING(%(max_length)s)", + "FloatField": "FLOAT64", + "IntegerField": "INT64", + "BigIntegerField": "INT64", + "IPAddressField": "STRING(15)", + "GenericIPAddressField": "STRING(39)", + "NullBooleanField": "BOOL", + "OneToOneField": "INT64", + "PositiveIntegerField": "INT64", + "PositiveSmallIntegerField": "INT64", + "SlugField": "STRING(%(max_length)s)", + "SmallAutoField": "INT64", + "SmallIntegerField": "INT64", + "TextField": "STRING(MAX)", + "TimeField": "TIMESTAMP", + "UUIDField": "STRING(32)", } operators = { - 'exact': '= %s', - 'iexact': 'REGEXP_CONTAINS(%s, %%%%s)', + "exact": "= %s", + "iexact": "REGEXP_CONTAINS(%s, %%%%s)", # contains uses REGEXP_CONTAINS instead of LIKE to allow # DatabaseOperations.prep_for_like_query() to do regular expression # escaping. prep_for_like_query() is called for all the lookups that # use REGEXP_CONTAINS except regex/iregex (see # django.db.models.lookups.PatternLookup). - 'contains': 'REGEXP_CONTAINS(%s, %%%%s)', - 'icontains': 'REGEXP_CONTAINS(%s, %%%%s)', - 'gt': '> %s', - 'gte': '>= %s', - 'lt': '< %s', - 'lte': '<= %s', + "contains": "REGEXP_CONTAINS(%s, %%%%s)", + "icontains": "REGEXP_CONTAINS(%s, %%%%s)", + "gt": "> %s", + "gte": ">= %s", + "lt": "< %s", + "lte": "<= %s", # Using REGEXP_CONTAINS instead of STARTS_WITH and ENDS_WITH for the # same reasoning as described above for 'contains'. - 'startswith': 'REGEXP_CONTAINS(%s, %%%%s)', - 'endswith': 'REGEXP_CONTAINS(%s, %%%%s)', - 'istartswith': 'REGEXP_CONTAINS(%s, %%%%s)', - 'iendswith': 'REGEXP_CONTAINS(%s, %%%%s)', - 'regex': 'REGEXP_CONTAINS(%s, %%%%s)', - 'iregex': 'REGEXP_CONTAINS(%s, %%%%s)', + "startswith": "REGEXP_CONTAINS(%s, %%%%s)", + "endswith": "REGEXP_CONTAINS(%s, %%%%s)", + "istartswith": "REGEXP_CONTAINS(%s, %%%%s)", + "iendswith": "REGEXP_CONTAINS(%s, %%%%s)", + "regex": "REGEXP_CONTAINS(%s, %%%%s)", + "iregex": "REGEXP_CONTAINS(%s, %%%%s)", } # pattern_esc is used to generate SQL pattern lookup clauses when the @@ -81,16 +81,18 @@ class DatabaseWrapper(BaseDatabaseWrapper): # expression or the result of a bilateral transformation). In those cases, # special characters for REGEXP_CONTAINS operators (e.g. \, *, _) must be # escaped on database side. - pattern_esc = r'REPLACE(REPLACE(REPLACE({}, "\\", "\\\\"), "%%", r"\%%"), "_", r"\_")' + pattern_esc = ( + r'REPLACE(REPLACE(REPLACE({}, "\\", "\\\\"), "%%", r"\%%"), "_", r"\_")' + ) # These are all no-ops in favor of using REGEXP_CONTAINS in the customized # lookups. pattern_ops = { - 'contains': '', - 'icontains': '', - 'startswith': '', - 'istartswith': '', - 'endswith': '', - 'iendswith': '', + "contains": "", + "icontains": "", + "startswith": "", + "istartswith": "", + "endswith": "", + "iendswith": "", } Database = Database @@ -104,7 +106,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): @property def instance(self): - return spanner.Client().instance(self.settings_dict['INSTANCE']) + return spanner.Client().instance(self.settings_dict["INSTANCE"]) @property def _nodb_connection(self): @@ -112,11 +114,11 @@ def _nodb_connection(self): def get_connection_params(self): return { - 'project': self.settings_dict['PROJECT'], - 'instance': self.settings_dict['INSTANCE'], - 'database': self.settings_dict['NAME'], - 'user_agent': 'django_spanner/0.0.1', - **self.settings_dict['OPTIONS'], + "project": self.settings_dict["PROJECT"], + "instance_id": self.settings_dict["INSTANCE"], + "database_id": self.settings_dict["NAME"], + "user_agent": "django_spanner/0.0.1", + **self.settings_dict["OPTIONS"], } def get_new_connection(self, conn_params): @@ -137,7 +139,7 @@ def is_usable(self): return False try: # Use a cursor directly, bypassing Django's utilities. - self.connection.cursor().execute('SELECT 1') + self.connection.cursor().execute("SELECT 1") except Database.Error: return False else: From e57da345461933dfd89eb97a4a0aa003c28b4e3c Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Fri, 28 Aug 2020 10:49:52 +0300 Subject: [PATCH 11/12] style fixes --- spanner_dbapi/connection.py | 22 +++++++------ spanner_dbapi/cursor.py | 43 ++++++++++++++++++-------- tests/spanner_dbapi/test_connection.py | 30 ++++++++++++++++++ tests/spanner_dbapi/test_cursor.py | 4 +++ 4 files changed, 77 insertions(+), 22 deletions(-) create mode 100644 tests/spanner_dbapi/test_connection.py diff --git a/spanner_dbapi/connection.py b/spanner_dbapi/connection.py index 0ae8c84d27..b0279de063 100644 --- a/spanner_dbapi/connection.py +++ b/spanner_dbapi/connection.py @@ -22,11 +22,11 @@ def __init__(self, db_handle): self.is_closed = False def cursor(self): - self._raise_if_already_closed() + self._raise_if_closed() return Cursor(self) - def _raise_if_already_closed(self): + def _raise_if_closed(self): """Raise an exception if this connection is closed. Helper to check the connection state before @@ -46,24 +46,24 @@ def __handle_update_ddl(self, ddl_statements): Returns: google.api_core.operation.Operation.result() """ - self._raise_if_already_closed() + self._raise_if_closed() # Synchronously wait on the operation's completion. return self._dbhandle.update_ddl(ddl_statements).result() def read_snapshot(self): - self._raise_if_already_closed() + self._raise_if_closed() return self._dbhandle.snapshot() def in_transaction(self, fn, *args, **kwargs): - self._raise_if_already_closed() + self._raise_if_closed() return self._dbhandle.run_in_transaction(fn, *args, **kwargs) def append_ddl_statement(self, ddl_statement): - self._raise_if_already_closed() + self._raise_if_closed() self._ddl_statements.append(ddl_statement) def run_prior_DDL_statements(self): - self._raise_if_already_closed() + self._raise_if_closed() if not self._ddl_statements: return @@ -116,17 +116,21 @@ def get_table_column_schema(self, table_name): return column_details def close(self): + """Close this connection. + + The connection will be unusable from this point forward. + """ self.rollback() self.__dbhandle = None self.is_closed = True def commit(self): - self._raise_if_already_closed() + self._raise_if_closed() self.run_prior_DDL_statements() def rollback(self): - self._raise_if_already_closed() + self._raise_if_closed() # TODO: to be added. diff --git a/spanner_dbapi/cursor.py b/spanner_dbapi/cursor.py index ebf0cb66f0..2b0bc9cce7 100644 --- a/spanner_dbapi/cursor.py +++ b/spanner_dbapi/cursor.py @@ -4,7 +4,14 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -import google.api_core.exceptions as grpc_exceptions +"""Database cursor API.""" + +from google.api_core.exceptions import ( + AlreadyExists, + FailedPrecondition, + InternalServerError, + InvalidArgument, +) from google.cloud.spanner_v1 import param_types from .exceptions import ( @@ -47,6 +54,13 @@ class Cursor: + """ + Database cursor to manage the context of a fetch operation. + + :type connection: :class:`spanner_dbapi.connection.Connection` + :param connection: Parent connection object for this Cursor. + """ + def __init__(self, connection): self._itr = None self._res = None @@ -66,7 +80,7 @@ def execute(self, sql, args=None): Returns: None """ - self._raise_if_already_closed() + self._raise_if_closed() if not self._connection: raise ProgrammingError("Cursor is not connected to the database") @@ -90,11 +104,11 @@ def execute(self, sql, args=None): self.__handle_insert(sql, args or None) else: self.__handle_update(sql, args or None) - except (grpc_exceptions.AlreadyExists, grpc_exceptions.FailedPrecondition) as e: + except (AlreadyExists, FailedPrecondition) as e: raise IntegrityError(e.details if hasattr(e, "details") else e) - except grpc_exceptions.InvalidArgument as e: + except InvalidArgument as e: raise ProgrammingError(e.details if hasattr(e, "details") else e) - except grpc_exceptions.InternalServerError as e: + except InternalServerError as e: raise OperationalError(e.details if hasattr(e, "details") else e) def __handle_update(self, sql, params): @@ -224,14 +238,13 @@ def rowcount(self): def is_closed(self): """The cursor close indicator. - Returns: - bool: - True if this cursor or it's parent connection - is closed, False otherwise. + :rtype: :class:`bool` + :returns: True if this cursor or it's parent connection is closed, False + otherwise. """ return self._is_closed or self._connection.is_closed - def _raise_if_already_closed(self): + def _raise_if_closed(self): """Raise an exception if this cursor is closed. Helper to check this cursor's state before running a @@ -244,6 +257,10 @@ def _raise_if_already_closed(self): raise InterfaceError("cursor is already closed") def close(self): + """Close this cursor. + + The cursor will be unusable from this point forward. + """ self.__clear() self._is_closed = True @@ -265,7 +282,7 @@ def __iter__(self): return self._itr def fetchone(self): - self._raise_if_already_closed() + self._raise_if_closed() try: return next(self) @@ -273,7 +290,7 @@ def fetchone(self): return None def fetchall(self): - self._raise_if_already_closed() + self._raise_if_closed() return list(self.__iter__()) @@ -290,7 +307,7 @@ def fetchmany(self, size=None): Error if the previous call to .execute*() did not produce any result set or if no call was issued yet. """ - self._raise_if_already_closed() + self._raise_if_closed() if size is None: size = self.arraysize diff --git a/tests/spanner_dbapi/test_connection.py b/tests/spanner_dbapi/test_connection.py new file mode 100644 index 0000000000..63d4f3c89c --- /dev/null +++ b/tests/spanner_dbapi/test_connection.py @@ -0,0 +1,30 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +"""Connection() class unit tests.""" + +import unittest +from unittest import mock + +from spanner_dbapi import connect, InterfaceError + + +class TestConnection(unittest.TestCase): + def test_close(self): + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=True + ): + connection = connect("test-instance", "test-database") + + self.assertFalse(connection.is_closed) + connection.close() + self.assertTrue(connection.is_closed) + + with self.assertRaises(InterfaceError): + connection.cursor() diff --git a/tests/spanner_dbapi/test_cursor.py b/tests/spanner_dbapi/test_cursor.py index 69d41f4217..4fa44a476d 100644 --- a/tests/spanner_dbapi/test_cursor.py +++ b/tests/spanner_dbapi/test_cursor.py @@ -23,6 +23,8 @@ def test_close(self): connection = connect("test-instance", "test-database") cursor = connection.cursor() + self.assertFalse(cursor.is_closed) + cursor.close() self.assertTrue(cursor.is_closed) @@ -39,6 +41,8 @@ def test_connection_closed(self): connection = connect("test-instance", "test-database") cursor = connection.cursor() + self.assertFalse(cursor.is_closed) + connection.close() self.assertTrue(cursor.is_closed) From c3ce46e1e6df05d209bbebe4cc45c40aad142321 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Tue, 1 Sep 2020 10:15:01 +0300 Subject: [PATCH 12/12] fix lint issues --- spanner_dbapi/cursor.py | 11 ++++++++--- tests/spanner_dbapi/test_connection.py | 6 ++++-- tests/spanner_dbapi/test_cursor.py | 12 ++++++++---- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/spanner_dbapi/cursor.py b/spanner_dbapi/cursor.py index 1c3b57dcc9..10e5184ed2 100644 --- a/spanner_dbapi/cursor.py +++ b/spanner_dbapi/cursor.py @@ -68,7 +68,8 @@ def __init__(self, connection): self._connection = connection self._is_closed = False - self.arraysize = 1 # the number of rows to fetch at a time with fetchmany() + # the number of rows to fetch at a time with fetchmany() + self.arraysize = 1 def execute(self, sql, args=None): """ @@ -397,9 +398,13 @@ def __str__(self): None if not self.internal_size else "internal_size=%d" % self.internal_size, - None if not self.precision else "precision='%s'" % self.precision, + None + if not self.precision + else "precision='%s'" % self.precision, None if not self.scale else "scale='%s'" % self.scale, - None if not self.null_ok else "null_ok='%s'" % self.null_ok, + None + if not self.null_ok + else "null_ok='%s'" % self.null_ok, ] if field ] diff --git a/tests/spanner_dbapi/test_connection.py b/tests/spanner_dbapi/test_connection.py index 63d4f3c89c..ab72f799df 100644 --- a/tests/spanner_dbapi/test_connection.py +++ b/tests/spanner_dbapi/test_connection.py @@ -15,10 +15,12 @@ class TestConnection(unittest.TestCase): def test_close(self): with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=True + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, ): with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", return_value=True + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, ): connection = connect("test-instance", "test-database") diff --git a/tests/spanner_dbapi/test_cursor.py b/tests/spanner_dbapi/test_cursor.py index 4fa44a476d..6bf6bb27e4 100644 --- a/tests/spanner_dbapi/test_cursor.py +++ b/tests/spanner_dbapi/test_cursor.py @@ -15,10 +15,12 @@ class TestCursor(unittest.TestCase): def test_close(self): with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=True + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, ): with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", return_value=True + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, ): connection = connect("test-instance", "test-database") @@ -33,10 +35,12 @@ def test_close(self): def test_connection_closed(self): with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=True + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, ): with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", return_value=True + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, ): connection = connect("test-instance", "test-database")