From d652bf2f95fbfc91dcab21dcf01aea979e99116e Mon Sep 17 00:00:00 2001 From: Vikash Singh <3116482+vi3k6i5@users.noreply.github.com> Date: Sun, 9 May 2021 12:28:36 +0530 Subject: [PATCH] feat: add decimal/numeric support --- django_spanner/base.py | 2 +- django_spanner/introspection.py | 1 + django_spanner/lookups.py | 7 +- django_spanner/operations.py | 26 +---- noxfile.py | 57 ++++++++++- tests/system/conftest.py | 19 ++++ tests/system/django_spanner/__init__.py | 0 tests/system/django_spanner/models.py | 23 +++++ tests/system/django_spanner/test_decimal.py | 108 ++++++++++++++++++++ tests/system/settings.py | 46 +++++++++ 10 files changed, 256 insertions(+), 33 deletions(-) create mode 100644 tests/system/conftest.py create mode 100644 tests/system/django_spanner/__init__.py create mode 100644 tests/system/django_spanner/models.py create mode 100644 tests/system/django_spanner/test_decimal.py create mode 100644 tests/system/settings.py diff --git a/django_spanner/base.py b/django_spanner/base.py index 9b0824a25c..3c209a0c0a 100644 --- a/django_spanner/base.py +++ b/django_spanner/base.py @@ -34,7 +34,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): "CharField": "STRING(%(max_length)s)", "DateField": "DATE", "DateTimeField": "TIMESTAMP", - "DecimalField": "FLOAT64", + "DecimalField": "NUMERIC", "DurationField": "INT64", "EmailField": "STRING(%(max_length)s)", "FileField": "STRING(%(max_length)s)", diff --git a/django_spanner/introspection.py b/django_spanner/introspection.py index 9cefd0687f..1cbb50a28a 100644 --- a/django_spanner/introspection.py +++ b/django_spanner/introspection.py @@ -24,6 +24,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): TypeCode.INT64: "IntegerField", TypeCode.STRING: "CharField", TypeCode.TIMESTAMP: "DateTimeField", + TypeCode.NUMERIC: "DecimalField", } def get_field_type(self, data_type, description): diff --git a/django_spanner/lookups.py b/django_spanner/lookups.py index cad536c914..bfe9538c56 100644 --- a/django_spanner/lookups.py +++ b/django_spanner/lookups.py @@ -233,13 +233,8 @@ def cast_param_to_float(self, compiler, connection): """ sql, params = self.as_sql(compiler, connection) if params: - # Cast to DecimaField lookup values to float because - # google.cloud.spanner_v1._helpers._make_value_pb() doesn't serialize - # decimal.Decimal. - if isinstance(self.lhs.output_field, DecimalField): - params[0] = float(params[0]) # Cast remote field lookups that must be integer but come in as string. - elif hasattr(self.lhs.output_field, "get_path_info"): + if hasattr(self.lhs.output_field, "get_path_info"): for i, field in enumerate( self.lhs.output_field.get_path_info()[-1].target_fields ): diff --git a/django_spanner/operations.py b/django_spanner/operations.py index e3ff7471ec..4038f9cd23 100644 --- a/django_spanner/operations.py +++ b/django_spanner/operations.py @@ -203,12 +203,12 @@ def adapt_decimalfield_value( :param decimal_places: (Optional) The number of decimal places to store with the number. - :rtype: float + :rtype: Decimal :returns: Formatted value. """ if value is None: return None - return float(value) + return Decimal(value) def adapt_timefield_value(self, value): """ @@ -244,8 +244,6 @@ def get_db_converters(self, expression): internal_type = expression.output_field.get_internal_type() if internal_type == "DateTimeField": converters.append(self.convert_datetimefield_value) - elif internal_type == "DecimalField": - converters.append(self.convert_decimalfield_value) elif internal_type == "TimeField": converters.append(self.convert_timefield_value) elif internal_type == "BinaryField": @@ -311,26 +309,6 @@ def convert_datetimefield_value(self, value, expression, connection): else dt ) - def convert_decimalfield_value(self, value, expression, connection): - """Convert Spanner DecimalField value for Django. - - :type value: float - :param value: A decimal field. - - :type expression: :class:`django.db.models.expressions.BaseExpression` - :param expression: A query expression. - - :type connection: :class:`~google.cloud.cpanner_dbapi.connection.Connection` - :param connection: Reference to a Spanner database connection. - - :rtype: :class:`Decimal` - :returns: A converted decimal field. - """ - if value is None: - return value - # Cloud Spanner returns a float. - return Decimal(str(value)) - def convert_timefield_value(self, value, expression, connection): """Convert Spanner TimeField value for Django. diff --git a/noxfile.py b/noxfile.py index a19bbc4360..883cb51149 100644 --- a/noxfile.py +++ b/noxfile.py @@ -10,6 +10,7 @@ from __future__ import absolute_import import os +import pathlib import shutil import nox @@ -25,7 +26,9 @@ DEFAULT_PYTHON_VERSION = "3.8" SYSTEM_TEST_PYTHON_VERSIONS = ["3.8"] -UNIT_TEST_PYTHON_VERSIONS = ["3.6", "3.7", "3.8"] +UNIT_TEST_PYTHON_VERSIONS = ["3.6", "3.7", "3.8", "3.9"] + +CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute() @nox.session(python=DEFAULT_PYTHON_VERSION) @@ -81,7 +84,7 @@ def default(session): "--cov-report=", "--cov-fail-under=20", os.path.join("tests", "unit"), - *session.posargs + *session.posargs, ) @@ -91,6 +94,56 @@ def unit(session): default(session) +@nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS) +def system(session): + """Run the system test suite.""" + constraints_path = str( + CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt" + ) + system_test_path = os.path.join("tests", "system.py") + system_test_folder_path = os.path.join("tests", "system") + + # Check the value of `RUN_SYSTEM_TESTS` env var. It defaults to true. + if os.environ.get("RUN_SYSTEM_TESTS", "true") == "false": + session.skip("RUN_SYSTEM_TESTS is set to false, skipping") + # Sanity check: Only run tests if the environment variable is set. + if not os.environ.get( + "GOOGLE_APPLICATION_CREDENTIALS", "" + ) and not os.environ.get("SPANNER_EMULATOR_HOST", ""): + session.skip( + "Credentials or emulator host must be set via environment variable" + ) + + system_test_exists = os.path.exists(system_test_path) + system_test_folder_exists = os.path.exists(system_test_folder_path) + # Sanity check: only run tests if found. + if not system_test_exists and not system_test_folder_exists: + session.skip("System tests were not found") + + # Use pre-release gRPC for system tests. + session.install("--pre", "grpcio") + + # Install all test dependencies, then install this package into the + # virtualenv's dist-packages. + session.install( + "django~=2.2", + "mock", + "pytest", + "google-cloud-testutils", + "-c", + constraints_path, + ) + session.install("-e", ".[tracing]", "-c", constraints_path) + + # Run py.test against the system tests. + if system_test_exists: + session.run("py.test", "--quiet", system_test_path, *session.posargs) + if system_test_folder_exists: + session.run( + "py.test", "--quiet", system_test_folder_path, *session.posargs + ) + + @nox.session(python=DEFAULT_PYTHON_VERSION) def cover(session): """Run the final coverage report. diff --git a/tests/system/conftest.py b/tests/system/conftest.py new file mode 100644 index 0000000000..6f83461870 --- /dev/null +++ b/tests/system/conftest.py @@ -0,0 +1,19 @@ +# Copyright 2021 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +import os +import django +from django.conf import settings + +# We manually designate which settings we will be using in an environment +# variable. This is similar to what occurs in the `manage.py` file. +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "tests.system.settings") + + +# `pytest` automatically calls this function once when tests are run. +def pytest_configure(): + settings.DEBUG = False + django.setup() diff --git a/tests/system/django_spanner/__init__.py b/tests/system/django_spanner/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/system/django_spanner/models.py b/tests/system/django_spanner/models.py new file mode 100644 index 0000000000..8d04a73e2d --- /dev/null +++ b/tests/system/django_spanner/models.py @@ -0,0 +1,23 @@ +# Copyright 2021 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +""" +Different models used by system tests in django-spanner code. +""" +from django.db import models + + +class Author(models.Model): + first_name = models.CharField(max_length=20) + last_name = models.CharField(max_length=20) + ratting = models.DecimalField() + + +class Number(models.Model): + num = models.DecimalField() + + def __str__(self): + return str(self.num) diff --git a/tests/system/django_spanner/test_decimal.py b/tests/system/django_spanner/test_decimal.py new file mode 100644 index 0000000000..07728c538a --- /dev/null +++ b/tests/system/django_spanner/test_decimal.py @@ -0,0 +1,108 @@ +# Copyright 2021 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +from .models import Author, Number +from django.test import TransactionTestCase +from django.db import connection, ProgrammingError +from decimal import Decimal + + +class TestDecimal(TransactionTestCase): + @classmethod + def setUpClass(cls): + with connection.schema_editor() as editor: + # Create the tables + editor.create_model(Author) + editor.create_model(Number) + + @classmethod + def tearDownClass(cls): + with connection.schema_editor() as editor: + # delete the table + editor.delete_model(Author) + editor.delete_model(Number) + + def ratting_transform(self, value): + return value["ratting"] + + def values_transform(self, value): + return value.num + + def assertValuesEqual( + self, queryset, expected_values, transformer, ordered=True + ): + self.assertQuerysetEqual( + queryset, expected_values, transformer, ordered + ) + + def test_insert_and_search_decimal_value(self): + """ + Tests model object creation with Author model. + """ + author_kent = Author( + first_name="Arthur", + last_name="Kent", + ratting=Decimal("4.1"), + ) + author_kent.save() + qs1 = Author.objects.filter(ratting__gte=3).values("ratting") + self.assertValuesEqual( + qs1, + [Decimal("4.1")], + self.ratting_transform, + ) + # Delete data from Author table. + Author.objects.all().delete() + + def test_decimal_filter(self): + """ + Tests decimal filter query. + """ + # Insert data into Number table. + Number.objects.bulk_create( + Number(num=Decimal(i) / Decimal(10)) for i in range(10) + ) + qs1 = Number.objects.filter(num__lte=Decimal(2) / Decimal(10)) + self.assertValuesEqual( + qs1, + [Decimal(i) / Decimal(10) for i in range(3)], + self.values_transform, + ordered=False, + ) + # Delete data from Number table. + Number.objects.all().delete() + + def test_decimal_precision_limit(self): + """ + Tests decimal object precission limit. + """ + num_val = Number(num=Decimal(1) / Decimal(3)) + msg = "400 Invalid value for bind parameter a0: Expected NUMERIC." + with self.assertRaisesRegex(ProgrammingError, msg): + num_val.save() + + def test_decimal_update(self): + """ + Tests decimal object update. + """ + author_kent = Author( + first_name="Arthur", + last_name="Kent", + ratting=Decimal("4.1"), + ) + author_kent.save() + author_kent.ratting = Decimal("4.2") + author_kent.save() + qs1 = Author.objects.filter(ratting__gte=Decimal("4.2")).values( + "ratting" + ) + self.assertValuesEqual( + qs1, + [Decimal("4.2")], + self.ratting_transform, + ) + # Delete data from Author table. + Author.objects.all().delete() diff --git a/tests/system/settings.py b/tests/system/settings.py new file mode 100644 index 0000000000..b2eef6a3c8 --- /dev/null +++ b/tests/system/settings.py @@ -0,0 +1,46 @@ +# Copyright 2021 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +DEBUG = True +USE_TZ = True + +INSTALLED_APPS = [ + "django_spanner", # Must be the first entry + "django.contrib.contenttypes", + "django.contrib.auth", + "django.contrib.sites", + "django.contrib.sessions", + "django.contrib.messages", + "django.contrib.staticfiles", + "tests", +] + +TIME_ZONE = "UTC" + +DATABASES = { + "default": { + "ENGINE": "django_spanner", + "PROJECT": "appdev-soda-spanner-staging", + "INSTANCE": "django-test-instance", + "NAME": "django-test-db", + } +} +SECRET_KEY = "spanner env secret key" + +PASSWORD_HASHERS = [ + "django.contrib.auth.hashers.MD5PasswordHasher", +] + +SITE_ID = 1 + +CONN_MAX_AGE = 60 + +ENGINE = "django_spanner" +PROJECT = "emulator-local" +INSTANCE = "django-test-instance" +NAME = "django-test-db" +OPTIONS = {} +AUTOCOMMIT = True