Skip to content

Commit

Permalink
feat: refactor connect() function, cover it with unit tests (#462)
Browse files Browse the repository at this point in the history
  • Loading branch information
Gurov Ilya authored and c24t committed Sep 15, 2020
1 parent 96f2223 commit 4fedcf1
Show file tree
Hide file tree
Showing 5 changed files with 253 additions and 113 deletions.
116 changes: 59 additions & 57 deletions django_spanner/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,79 +18,81 @@


class DatabaseWrapper(BaseDatabaseWrapper):
vendor = 'spanner'
display_name = 'Cloud Spanner'
vendor = "spanner"
display_name = "Cloud Spanner"

# Mapping of Field objects to their column types.
# https://cloud.google.com/spanner/docs/data-types#date-type
data_types = {
'AutoField': 'INT64',
'BigAutoField': 'INT64',
'BinaryField': 'BYTES(MAX)',
'BooleanField': 'BOOL',
'CharField': 'STRING(%(max_length)s)',
'DateField': 'DATE',
'DateTimeField': 'TIMESTAMP',
'DecimalField': 'FLOAT64',
'DurationField': 'INT64',
'EmailField': 'STRING(%(max_length)s)',
'FileField': 'STRING(%(max_length)s)',
'FilePathField': 'STRING(%(max_length)s)',
'FloatField': 'FLOAT64',
'IntegerField': 'INT64',
'BigIntegerField': 'INT64',
'IPAddressField': 'STRING(15)',
'GenericIPAddressField': 'STRING(39)',
'NullBooleanField': 'BOOL',
'OneToOneField': 'INT64',
'PositiveIntegerField': 'INT64',
'PositiveSmallIntegerField': 'INT64',
'SlugField': 'STRING(%(max_length)s)',
'SmallAutoField': 'INT64',
'SmallIntegerField': 'INT64',
'TextField': 'STRING(MAX)',
'TimeField': 'TIMESTAMP',
'UUIDField': 'STRING(32)',
"AutoField": "INT64",
"BigAutoField": "INT64",
"BinaryField": "BYTES(MAX)",
"BooleanField": "BOOL",
"CharField": "STRING(%(max_length)s)",
"DateField": "DATE",
"DateTimeField": "TIMESTAMP",
"DecimalField": "FLOAT64",
"DurationField": "INT64",
"EmailField": "STRING(%(max_length)s)",
"FileField": "STRING(%(max_length)s)",
"FilePathField": "STRING(%(max_length)s)",
"FloatField": "FLOAT64",
"IntegerField": "INT64",
"BigIntegerField": "INT64",
"IPAddressField": "STRING(15)",
"GenericIPAddressField": "STRING(39)",
"NullBooleanField": "BOOL",
"OneToOneField": "INT64",
"PositiveIntegerField": "INT64",
"PositiveSmallIntegerField": "INT64",
"SlugField": "STRING(%(max_length)s)",
"SmallAutoField": "INT64",
"SmallIntegerField": "INT64",
"TextField": "STRING(MAX)",
"TimeField": "TIMESTAMP",
"UUIDField": "STRING(32)",
}
operators = {
'exact': '= %s',
'iexact': 'REGEXP_CONTAINS(%s, %%%%s)',
"exact": "= %s",
"iexact": "REGEXP_CONTAINS(%s, %%%%s)",
# contains uses REGEXP_CONTAINS instead of LIKE to allow
# DatabaseOperations.prep_for_like_query() to do regular expression
# escaping. prep_for_like_query() is called for all the lookups that
# use REGEXP_CONTAINS except regex/iregex (see
# django.db.models.lookups.PatternLookup).
'contains': 'REGEXP_CONTAINS(%s, %%%%s)',
'icontains': 'REGEXP_CONTAINS(%s, %%%%s)',
'gt': '> %s',
'gte': '>= %s',
'lt': '< %s',
'lte': '<= %s',
"contains": "REGEXP_CONTAINS(%s, %%%%s)",
"icontains": "REGEXP_CONTAINS(%s, %%%%s)",
"gt": "> %s",
"gte": ">= %s",
"lt": "< %s",
"lte": "<= %s",
# Using REGEXP_CONTAINS instead of STARTS_WITH and ENDS_WITH for the
# same reasoning as described above for 'contains'.
'startswith': 'REGEXP_CONTAINS(%s, %%%%s)',
'endswith': 'REGEXP_CONTAINS(%s, %%%%s)',
'istartswith': 'REGEXP_CONTAINS(%s, %%%%s)',
'iendswith': 'REGEXP_CONTAINS(%s, %%%%s)',
'regex': 'REGEXP_CONTAINS(%s, %%%%s)',
'iregex': 'REGEXP_CONTAINS(%s, %%%%s)',
"startswith": "REGEXP_CONTAINS(%s, %%%%s)",
"endswith": "REGEXP_CONTAINS(%s, %%%%s)",
"istartswith": "REGEXP_CONTAINS(%s, %%%%s)",
"iendswith": "REGEXP_CONTAINS(%s, %%%%s)",
"regex": "REGEXP_CONTAINS(%s, %%%%s)",
"iregex": "REGEXP_CONTAINS(%s, %%%%s)",
}

# pattern_esc is used to generate SQL pattern lookup clauses when the
# right-hand side of the lookup isn't a raw string (it might be an
# expression or the result of a bilateral transformation). In those cases,
# special characters for REGEXP_CONTAINS operators (e.g. \, *, _) must be
# escaped on database side.
pattern_esc = r'REPLACE(REPLACE(REPLACE({}, "\\", "\\\\"), "%%", r"\%%"), "_", r"\_")'
pattern_esc = (
r'REPLACE(REPLACE(REPLACE({}, "\\", "\\\\"), "%%", r"\%%"), "_", r"\_")'
)
# These are all no-ops in favor of using REGEXP_CONTAINS in the customized
# lookups.
pattern_ops = {
'contains': '',
'icontains': '',
'startswith': '',
'istartswith': '',
'endswith': '',
'iendswith': '',
"contains": "",
"icontains": "",
"startswith": "",
"istartswith": "",
"endswith": "",
"iendswith": "",
}

Database = Database
Expand All @@ -104,19 +106,19 @@ class DatabaseWrapper(BaseDatabaseWrapper):

@property
def instance(self):
return spanner.Client().instance(self.settings_dict['INSTANCE'])
return spanner.Client().instance(self.settings_dict["INSTANCE"])

@property
def _nodb_connection(self):
raise NotImplementedError('Spanner does not have a "no db" connection.')

def get_connection_params(self):
return {
'project': self.settings_dict['PROJECT'],
'instance': self.settings_dict['INSTANCE'],
'database': self.settings_dict['NAME'],
'user_agent': 'django_spanner/0.0.1',
**self.settings_dict['OPTIONS'],
"project": self.settings_dict["PROJECT"],
"instance_id": self.settings_dict["INSTANCE"],
"database_id": self.settings_dict["NAME"],
"user_agent": "django_spanner/0.0.1",
**self.settings_dict["OPTIONS"],
}

def get_new_connection(self, conn_params):
Expand All @@ -137,7 +139,7 @@ def is_usable(self):
return False
try:
# Use a cursor directly, bypassing Django's utilities.
self.connection.cursor().execute('SELECT 1')
self.connection.cursor().execute("SELECT 1")
except Database.Error:
return False
else:
Expand Down
4 changes: 3 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
max-line-length = 119

[isort]
use_parentheses=True
combine_as_imports = true
default_section = THIRDPARTY
include_trailing_comma = true
force_grid_wrap=0
line_length = 79
multi_line_output = 5
multi_line_output = 3
149 changes: 94 additions & 55 deletions spanner_dbapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,83 +4,122 @@
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

from google.cloud import spanner_v1 as spanner
"""Connection-based DB API for Cloud Spanner."""

from google.cloud import spanner_v1

from .connection import Connection
# These need to be included in the top-level package for PEP-0249 DB API v2.
from .exceptions import (
DatabaseError, DataError, Error, IntegrityError, InterfaceError,
InternalError, NotSupportedError, OperationalError, ProgrammingError,
DatabaseError,
DataError,
Error,
IntegrityError,
InterfaceError,
InternalError,
NotSupportedError,
OperationalError,
ProgrammingError,
Warning,
)
from .parse_utils import get_param_types
from .types import (
BINARY, DATETIME, NUMBER, ROWID, STRING, Binary, Date, DateFromTicks, Time,
TimeFromTicks, Timestamp, TimestampFromTicks,
BINARY,
DATETIME,
NUMBER,
ROWID,
STRING,
Binary,
Date,
DateFromTicks,
Time,
TimeFromTicks,
Timestamp,
TimestampFromTicks,
)
from .version import google_client_info

# Globals that MUST be defined ###
apilevel = "2.0" # Implements the Python Database API specification 2.0 version.
# We accept arguments in the format '%s' aka ANSI C print codes.
# as per https://www.python.org/dev/peps/pep-0249/#paramstyle
paramstyle = 'format'
# Threads may share the module but not connections. This is a paranoid threadsafety level,
# but it is necessary for starters to use when debugging failures. Eventually once transactions
# are working properly, we'll update the threadsafety level.
apilevel = "2.0" # supports DP-API 2.0 level.
paramstyle = "format" # ANSI C printf format codes, e.g. ...WHERE name=%s.

# Threads may share the module, but not connections. This is a paranoid threadsafety
# level, but it is necessary for starters to use when debugging failures.
# Eventually once transactions are working properly, we'll update the
# threadsafety level.
threadsafety = 1


def connect(project=None, instance=None, database=None, credentials_uri=None, user_agent=None):
def connect(instance_id, database_id, project=None, credentials=None, user_agent=None):
"""
Connect to Cloud Spanner.
Create a connection to Cloud Spanner database.
Args:
project: The id of a project that already exists.
instance: The id of an instance that already exists.
database: The name of a database that already exists.
credentials_uri: An optional string specifying where to retrieve the service
account JSON for the credentials to connect to Cloud Spanner.
:type instance_id: :class:`str`
:param instance_id: ID of the instance to connect to.
Returns:
The Connection object associated to the Cloud Spanner instance.
:type database_id: :class:`str`
:param database_id: The name of the database to connect to.
Raises:
Error if it encounters any unexpected inputs.
"""
if not project:
raise Error("'project' is required.")
if not instance:
raise Error("'instance' is required.")
if not database:
raise Error("'database' is required.")
:type project: :class:`str`
:param project: (Optional) The ID of the project which owns the
instances, tables and data. If not provided, will
attempt to determine from the environment.
client_kwargs = {
'project': project,
'client_info': google_client_info(user_agent),
}
if credentials_uri:
client = spanner.Client.from_service_account_json(credentials_uri, **client_kwargs)
else:
client = spanner.Client(**client_kwargs)
:type credentials: :class:`google.auth.credentials.Credentials`
:param credentials: (Optional) The authorization credentials to attach to requests.
These credentials identify this application to the service.
If none are specified, the client will attempt to ascertain
the credentials from the environment.
:rtype: :class:`google.cloud.spanner_dbapi.connection.Connection`
:returns: Connection object associated with the given Cloud Spanner resource.
:raises: :class:`ValueError` in case of given instance/database
doesn't exist.
"""
client = spanner_v1.Client(
project=project,
credentials=credentials,
client_info=google_client_info(user_agent),
)

client_instance = client.instance(instance)
if not client_instance.exists():
raise ProgrammingError("instance '%s' does not exist." % instance)
instance = client.instance(instance_id)
if not instance.exists():
raise ValueError("instance '%s' does not exist." % instance_id)

db = client_instance.database(database, pool=spanner.pool.BurstyPool())
if not db.exists():
raise ProgrammingError("database '%s' does not exist." % database)
database = instance.database(database_id, pool=spanner_v1.pool.BurstyPool())
if not database.exists():
raise ValueError("database '%s' does not exist." % database_id)

return Connection(db)
return Connection(database)


__all__ = [
'DatabaseError', 'DataError', 'Error', 'IntegrityError', 'InterfaceError',
'InternalError', 'NotSupportedError', 'OperationalError', 'ProgrammingError',
'Warning', 'DEFAULT_USER_AGENT', 'apilevel', 'connect', 'paramstyle', 'threadsafety',
'get_param_types',
'Binary', 'Date', 'DateFromTicks', 'Time', 'TimeFromTicks', 'Timestamp',
'TimestampFromTicks',
'BINARY', 'STRING', 'NUMBER', 'DATETIME', 'ROWID', 'TimestampStr',
"DatabaseError",
"DataError",
"Error",
"IntegrityError",
"InterfaceError",
"InternalError",
"NotSupportedError",
"OperationalError",
"ProgrammingError",
"Warning",
"DEFAULT_USER_AGENT",
"apilevel",
"connect",
"paramstyle",
"threadsafety",
"get_param_types",
"Binary",
"Date",
"DateFromTicks",
"Time",
"TimeFromTicks",
"Timestamp",
"TimestampFromTicks",
"BINARY",
"STRING",
"NUMBER",
"DATETIME",
"ROWID",
"TimestampStr",
]
Loading

0 comments on commit 4fedcf1

Please sign in to comment.