Skip to content

Commit

Permalink
More changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ankiaga committed Apr 10, 2024
1 parent 07cc789 commit c44439a
Show file tree
Hide file tree
Showing 6 changed files with 279 additions and 168 deletions.
6 changes: 6 additions & 0 deletions django_spanner/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,9 @@ def is_usable(self):
return False

return True

def _start_transaction_under_autocommit(self):
"""
Start a transaction explicitly in autocommit mode.
"""
self.connection.cursor().execute("BEGIN")
23 changes: 13 additions & 10 deletions django_spanner/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,8 +519,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"contenttypes_tests.test_models.ContentTypesTests.test_app_labeled_name",
"db_functions.datetime.test_extract_trunc.DateFunctionTests.test_extract_lookup_name_sql_injection",
"db_functions.datetime.test_extract_trunc.DateFunctionTests.test_trunc_lookup_name_sql_injection",
#"db_functions.datetime.test_extract_trunc.DateFunctionTests.test_trunc_time_comparison",
#"db_functions.datetime.test_extract_trunc.DateFunctionWithTimeZoneTests.test_trunc_time_comparison",
"db_functions.datetime.test_extract_trunc.DateFunctionWithTimeZoneTests.test_extract_lookup_name_sql_injection",
"db_functions.datetime.test_extract_trunc.DateFunctionWithTimeZoneTests.test_trunc_lookup_name_sql_injection",
"db_functions.datetime.test_extract_trunc.DateFunctionWithTimeZoneTests.test_trunc_ambiguous_and_invalid_times",
Expand All @@ -531,16 +529,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"schema.tests.SchemaTests.test_add_auto_field",
"schema.tests.SchemaTests.test_alter_null_with_default_value_deferred_constraints",
"schema.tests.SchemaTests.test_autofield_to_o2o",
#"schema.tests.SchemaTests.test_add_field_durationfield_with_default",
#"backends.tests.LastExecutedQueryTest.test_last_executed_query_dict_overlap_keys",
#"backends.tests.LastExecutedQueryTest.test_last_executed_query_with_duplicate_params",
#"backends.tests.BackendTestCase.test_queries_logger",
"backends.tests.BackendTestCase.test_queries_bare_where",
"expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor_right_null",
"expressions.tests.FTimeDeltaTests.test_durationfield_multiply_divide",
#"generic_relations.tests.GenericRelationsTests.test_unsaved_generic_foreign_key_parent_bulk_create",
#"generic_relations.tests.GenericRelationsTests.test_unsaved_generic_foreign_key_parent_save",
#"get_or_create.tests.UpdateOrCreateTests.test_update_only_defaults_and_pre_save_fields_when_local_fields",
"inspectdb.tests.InspectDBTestCase.test_same_relations",
"migrations.test_operations.OperationTests.test_alter_field_pk_fk_char_to_int",
"migrations.test_operations.OperationTests.test_alter_field_with_func_unique_constraint",
Expand All @@ -550,6 +541,18 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"migrations.test_operations.OperationTests.test_rename_field_unique_together",
"migrations.test_operations.OperationTests.test_rename_model_with_db_table_rename_m2m",
"migrations.test_operations.OperationTests.test_rename_model_with_m2m_models_in_different_apps_with_same_name",
"delete.tests.DeletionTests.test_pk_none",
"db_functions.datetime.test_extract_trunc.DateFunctionTests.test_trunc_time_comparison",
"db_functions.datetime.test_extract_trunc.DateFunctionWithTimeZoneTests.test_trunc_time_comparison",
"backends.tests.LastExecutedQueryTest.test_last_executed_query_dict_overlap_keys",
"backends.tests.LastExecutedQueryTest.test_last_executed_query_with_duplicate_params",
"backends.tests.BackendTestCase.test_queries_logger",
"generic_relations.tests.GenericRelationsTests.test_unsaved_generic_foreign_key_parent_bulk_create",
"generic_relations.tests.GenericRelationsTests.test_unsaved_generic_foreign_key_parent_save",
"schema.tests.SchemaTests.test_add_field_durationfield_with_default",
"delete.tests.DeletionTests.test_only_referenced_fields_selected",
"bulk_create.tests.BulkCreateTests.test_explicit_batch_size_efficiency",
"get_or_create.tests.UpdateOrCreateTests.test_update_only_defaults_and_pre_save_fields_when_local_fields",
)

if os.environ.get("SPANNER_EMULATOR_HOST", None):
Expand Down Expand Up @@ -2071,7 +2074,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"auth_tests.test_management.CreatesuperuserManagementCommandTestCase.test_ignore_environment_variable_non_interactive", # noqa
"auth_tests.test_management.GetDefaultUsernameTestCase.test_with_database", # noqa
"auth_tests.test_management.MultiDBCreatesuperuserTestCase.test_createsuperuser_command_suggested_username_with_database_option", # noqa
"auth_tests.test_middleware.TestAuthenticationMiddleware.test_no_password_change_does_not_invalidate_legacy_session", # noqa
"auth_tests.test_middleware.TestAuthenticationMiddleware.test_no_session", # noqa
"auth_tests.test_middleware.TestAuthenticationMiddleware.test_session_default_hashing_algorithm", # noqa
"auth_tests.test_models.UserManagerTestCase.test_runpython_manager_methods", # noqa
Expand Down Expand Up @@ -2153,6 +2155,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"auth_tests.test_forms.UserCreationFormTest.test_validates_password", # noqa
"auth_tests.test_forms.UserCreationFormTest.test_html_autocomplete_attributes", # noqa
"auth_tests.test_forms.UserCreationFormTest.test_username_field_autocapitalize_none", # noqa
"auth_tests.test_middleware.TestAuthenticationMiddleware.test_no_password_change_does_not_invalidate_legacy_session", # noqa
)
if USING_DJANGO_4:
skip_tests += (
Expand Down
16 changes: 8 additions & 8 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def lint_setup_py(session):
)


def default(session, django_version="2.2"):
def default(session, django_version="3.2"):
# Install all test dependencies, then install this package in-place.
session.install(
"django~={}".format(django_version),
Expand All @@ -75,7 +75,7 @@ def default(session, django_version="2.2"):
"pytest",
"pytest-cov",
"coverage",
"sqlparse==0.3.0",
"sqlparse==0.3.1",
"google-cloud-spanner>=3.13.0",
"opentelemetry-api==1.1.0",
"opentelemetry-sdk==1.1.0",
Expand All @@ -101,13 +101,13 @@ def default(session, django_version="2.2"):
@nox.session(python=UNIT_TEST_PYTHON_VERSIONS)
def unit(session):
"""Run the unit test suite."""
print("Unit tests with django 2.2")
default(session)
print("Unit tests with django 3.2")
default(session, django_version="3.2")
default(session)
print("Unit tests with django 4.2")
default(session, django_version="4.2")


def system_test(session, django_version="2.2"):
def system_test(session, django_version="3.2"):
"""Run the system test suite."""
constraints_path = str(
CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt"
Expand Down Expand Up @@ -157,8 +157,8 @@ def system_test(session, django_version="2.2"):

@nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS)
def system(session):
# print("System tests with django 2.2")
# system_test(session)
print("System tests with django 3.2")
system_test(session)
print("System tests with django 4.2")
system_test(session, django_version="4.2")

Expand Down
181 changes: 125 additions & 56 deletions tests/unit/django_spanner/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from django.db.utils import DatabaseError
from django_spanner.compiler import SQLCompiler
from django.db.models.query import QuerySet
from django_spanner import USING_DJANGO_3, USING_DJANGO_4
from tests.unit.django_spanner.simple_test import SpannerSimpleTestClass
from .models import Number

Expand Down Expand Up @@ -38,14 +39,24 @@ def test_get_combinator_sql_all_union_sql_generated(self):

compiler = SQLCompiler(qs4.query, self.connection, "default")
sql_compiled, params = compiler.get_combinator_sql("union", True)
self.assertEqual(
sql_compiled,
[
"SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num <= %s UNION ALL SELECT tests_number.num "
+ "FROM tests_number WHERE tests_number.num >= %s"
],
)
if USING_DJANGO_3:
self.assertEqual(
sql_compiled,
[
"SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num <= %s UNION ALL SELECT tests_number.num "
+ "FROM tests_number WHERE tests_number.num >= %s"
],
)
elif USING_DJANGO_4:
self.assertEqual(
sql_compiled,
[
"SELECT tests_number.num AS col1 FROM tests_number WHERE "
+ "tests_number.num <= %s UNION ALL SELECT tests_number.num "
+ "AS col1 FROM tests_number WHERE tests_number.num >= %s"
],
)
self.assertEqual(params, [1, 8])

def test_get_combinator_sql_distinct_union_sql_generated(self):
Expand All @@ -59,15 +70,26 @@ def test_get_combinator_sql_distinct_union_sql_generated(self):

compiler = SQLCompiler(qs4.query, self.connection, "default")
sql_compiled, params = compiler.get_combinator_sql("union", False)
self.assertEqual(
sql_compiled,
[
"SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num <= %s UNION DISTINCT SELECT "
+ "tests_number.num FROM tests_number WHERE "
+ "tests_number.num >= %s"
],
)
if USING_DJANGO_3:
self.assertEqual(
sql_compiled,
[
"SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num <= %s UNION DISTINCT SELECT "
+ "tests_number.num FROM tests_number WHERE "
+ "tests_number.num >= %s"
],
)
elif USING_DJANGO_4:
self.assertEqual(
sql_compiled,
[
"SELECT tests_number.num AS col1 FROM tests_number WHERE "
+ "tests_number.num <= %s UNION DISTINCT SELECT "
+ "tests_number.num AS col1 FROM tests_number WHERE "
+ "tests_number.num >= %s"
],
)
self.assertEqual(params, [1, 8])

def test_get_combinator_sql_difference_all_sql_generated(self):
Expand All @@ -81,14 +103,24 @@ def test_get_combinator_sql_difference_all_sql_generated(self):
compiler = SQLCompiler(qs4.query, self.connection, "default")
sql_compiled, params = compiler.get_combinator_sql("difference", True)

self.assertEqual(
sql_compiled,
[
"SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num <= %s EXCEPT ALL SELECT tests_number.num "
+ "FROM tests_number WHERE tests_number.num >= %s"
],
)
if USING_DJANGO_3:
self.assertEqual(
sql_compiled,
[
"SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num <= %s EXCEPT ALL SELECT tests_number.num "
+ "FROM tests_number WHERE tests_number.num >= %s"
],
)
elif USING_DJANGO_4:
self.assertEqual(
sql_compiled,
[
"SELECT tests_number.num AS col1 FROM tests_number WHERE "
+ "tests_number.num <= %s EXCEPT ALL SELECT tests_number.num "
+ "AS col1 FROM tests_number WHERE tests_number.num >= %s"
],
)
self.assertEqual(params, [1, 8])

def test_get_combinator_sql_difference_distinct_sql_generated(self):
Expand All @@ -102,15 +134,26 @@ def test_get_combinator_sql_difference_distinct_sql_generated(self):
compiler = SQLCompiler(qs4.query, self.connection, "default")
sql_compiled, params = compiler.get_combinator_sql("difference", False)

self.assertEqual(
sql_compiled,
[
"SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num <= %s EXCEPT DISTINCT SELECT "
+ "tests_number.num FROM tests_number WHERE "
+ "tests_number.num >= %s"
],
)
if USING_DJANGO_3:
self.assertEqual(
sql_compiled,
[
"SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num <= %s EXCEPT DISTINCT SELECT "
+ "tests_number.num FROM tests_number WHERE "
+ "tests_number.num >= %s"
],
)
elif USING_DJANGO_4:
self.assertEqual(
sql_compiled,
[
"SELECT tests_number.num AS col1 FROM tests_number WHERE "
+ "tests_number.num <= %s EXCEPT DISTINCT SELECT "
+ "tests_number.num AS col1 FROM tests_number WHERE "
+ "tests_number.num >= %s"
],
)
self.assertEqual(params, [1, 8])

def test_get_combinator_sql_union_and_difference_query_together(self):
Expand All @@ -124,17 +167,30 @@ def test_get_combinator_sql_union_and_difference_query_together(self):

compiler = SQLCompiler(qs4.query, self.connection, "default")
sql_compiled, params = compiler.get_combinator_sql("union", False)
self.assertEqual(
sql_compiled,
[
"SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num <= %s UNION DISTINCT SELECT * FROM ("
+ "SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num >= %s EXCEPT DISTINCT "
+ "SELECT tests_number.num FROM tests_number "
+ "WHERE tests_number.num = %s)"
],
)
if USING_DJANGO_3:
self.assertEqual(
sql_compiled,
[
"SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num <= %s UNION DISTINCT SELECT * FROM ("
+ "SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num >= %s EXCEPT DISTINCT "
+ "SELECT tests_number.num FROM tests_number "
+ "WHERE tests_number.num = %s)"
],
)
elif USING_DJANGO_4:
self.assertEqual(
sql_compiled,
[
"SELECT tests_number.num AS col1 FROM tests_number WHERE "
+ "tests_number.num <= %s UNION DISTINCT SELECT * FROM ("
+ "SELECT tests_number.num AS col1 FROM tests_number WHERE "
+ "tests_number.num >= %s EXCEPT DISTINCT "
+ "SELECT tests_number.num AS col1 FROM tests_number "
+ "WHERE tests_number.num = %s)"
],
)
self.assertEqual(params, [1, 8, 10])

def test_get_combinator_sql_parentheses_in_compound_not_supported(self):
Expand All @@ -151,17 +207,30 @@ def test_get_combinator_sql_parentheses_in_compound_not_supported(self):
compiler = SQLCompiler(qs4.query, self.connection, "default")
compiler.connection.features.supports_parentheses_in_compound = False
sql_compiled, params = compiler.get_combinator_sql("union", False)
self.assertEqual(
sql_compiled,
[
"SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num <= %s UNION DISTINCT SELECT * FROM ("
+ "SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num >= %s EXCEPT DISTINCT "
+ "SELECT tests_number.num FROM tests_number "
+ "WHERE tests_number.num = %s)"
],
)
if USING_DJANGO_3:
self.assertEqual(
sql_compiled,
[
"SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num <= %s UNION DISTINCT SELECT * FROM ("
+ "SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num >= %s EXCEPT DISTINCT "
+ "SELECT tests_number.num FROM tests_number "
+ "WHERE tests_number.num = %s)"
],
)
elif USING_DJANGO_4:
self.assertEqual(
sql_compiled,
[
"SELECT tests_number.num AS col1 FROM tests_number WHERE "
+ "tests_number.num <= %s UNION DISTINCT SELECT * FROM ("
+ "SELECT tests_number.num AS col1 FROM tests_number WHERE "
+ "tests_number.num >= %s EXCEPT DISTINCT "
+ "SELECT tests_number.num AS col1 FROM tests_number "
+ "WHERE tests_number.num = %s)"
],
)
self.assertEqual(params, [1, 8, 10])

def test_get_combinator_sql_empty_queryset_raises_exception(self):
Expand Down
Loading

0 comments on commit c44439a

Please sign in to comment.