diff --git a/.github/sync-repo-settings.yaml b/.github/sync-repo-settings.yaml index fd39852d..c83e8a4f 100644 --- a/.github/sync-repo-settings.yaml +++ b/.github/sync-repo-settings.yaml @@ -11,6 +11,7 @@ branchProtectionRules: - 'unit' - 'compliance_tests_13' - 'compliance_tests_14' + - 'compliance_tests_20' - 'migration_tests' - 'cla/google' - 'Kokoro' diff --git a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py index 7f0b44a9..47176699 100644 --- a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py +++ b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py @@ -23,6 +23,7 @@ format_type, ) from sqlalchemy.exc import NoSuchTableError +from sqlalchemy.sql import elements from sqlalchemy import ForeignKeyConstraint, types from sqlalchemy.engine.base import Engine from sqlalchemy.engine.default import DefaultDialect, DefaultExecutionContext @@ -314,6 +315,12 @@ def render_literal_value(self, value, type_): in string. Override the method to add additional escape before using it to generate a SQL statement. """ + if value is None and not type_.should_evaluate_none: + # issue #10535 - handle NULL in the compiler without placing + # this onto each type, except for "evaluate None" types + # (e.g. JSON) + return self.process(elements.Null._instance()) + raw = ["\\", "'", '"', "\n", "\t", "\r"] if isinstance(value, str) and any(single in value for single in raw): value = 'r"""{}"""'.format(value) diff --git a/noxfile.py b/noxfile.py index 77df05f7..614e593d 100644 --- a/noxfile.py +++ b/noxfile.py @@ -145,7 +145,7 @@ def compliance_test_13(session): ) session.install("mock") - session.install("-e", ".[tracing]") + session.install(".[tracing]") session.run("pip", "install", "sqlalchemy>=1.1.13,<=1.3.24", "--force-reinstall") session.run("pip", "install", "opentelemetry-api<=1.10", "--force-reinstall") session.run("pip", "install", "opentelemetry-sdk<=1.10", "--force-reinstall") @@ -191,7 +191,7 @@ def compliance_test_14(session): ) session.install("mock") - session.install("-e", ".[tracing]") + session.install(".[tracing]") session.run("pip", "install", "sqlalchemy>=1.4,<2.0", "--force-reinstall") session.run("python", "create_test_database.py") session.run( @@ -231,7 +231,7 @@ def compliance_test_20(session): ) session.install("mock") - session.install("-e", ".[tracing]") + session.install(".[tracing]") session.run("pip", "install", "opentelemetry-api<=1.10", "--force-reinstall") session.run("python", "create_test_database.py") @@ -257,7 +257,7 @@ def unit(session): # Run SQLAlchemy dialect compliance test suite with OpenTelemetry. session.install("pytest") session.install("mock") - session.install("-e", ".") + session.install(".") session.install("opentelemetry-api==1.1.0") session.install("opentelemetry-sdk==1.1.0") session.install("opentelemetry-instrumentation==0.20b0") @@ -292,7 +292,7 @@ def _migration_test(session): session.run("pip", "install", "sqlalchemy>=1.3.11,<2.0", "--force-reinstall") session.install("pytest") - session.install("-e", ".") + session.install(".") session.install("alembic") session.run("python", "create_test_database.py") @@ -360,7 +360,7 @@ def snippets(session): session.install( "git+https://github.com/googleapis/python-spanner.git#egg=google-cloud-spanner" ) - session.install("-e", ".") + session.install(".") session.run("python", "create_test_database.py") session.run( "py.test", diff --git a/test/conftest.py b/test/conftest.py index 35767cf2..3b01359d 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -16,9 +16,72 @@ import pytest from sqlalchemy.dialects import registry +from sqlalchemy.testing.schema import Column +from sqlalchemy.testing.schema import Table +from sqlalchemy.sql.elements import literal registry.register("spanner", "google.cloud.sqlalchemy_spanner", "SpannerDialect") pytest.register_assert_rewrite("sqlalchemy.testing.assertions") from sqlalchemy.testing.plugin.pytestplugin import * # noqa: E402, F401, F403 + + +@pytest.fixture +def literal_round_trip_spanner(metadata, connection): + # for literal, we test the literal render in an INSERT + # into a typed column. we can then SELECT it back as its + # official type; + + def run( + type_, + input_, + output, + filter_=None, + compare=None, + support_whereclause=True, + ): + t = Table("t", metadata, Column("x", type_)) + t.create(connection) + + for value in input_: + ins = t.insert().values(x=literal(value, type_, literal_execute=True)) + connection.execute(ins) + + if support_whereclause: + if compare: + stmt = t.select().where( + t.c.x + == literal( + compare, + type_, + literal_execute=True, + ), + t.c.x + == literal( + input_[0], + type_, + literal_execute=True, + ), + ) + else: + stmt = t.select().where( + t.c.x + == literal( + compare if compare is not None else input_[0], + type_, + literal_execute=True, + ) + ) + else: + stmt = t.select() + + rows = connection.execute(stmt).all() + assert rows, "No rows returned" + for row in rows: + value = row[0] + if filter_ is not None: + value = filter_(value) + assert value in output + + return run diff --git a/test/test_suite_20.py b/test/test_suite_20.py index 50958aaa..8d6d8113 100644 --- a/test/test_suite_20.py +++ b/test/test_suite_20.py @@ -149,7 +149,10 @@ UnicodeTextTest as _UnicodeTextTest, _UnicodeFixture as __UnicodeFixture, ) # noqa: F401, F403 -from test._helpers import get_db_url, get_project +from test._helpers import ( + get_db_url, + get_project, +) config.test_schema = "" @@ -162,7 +165,7 @@ class BooleanTest(_BooleanTest): def test_render_literal_bool(self): pass - def test_render_literal_bool_true(self, literal_round_trip): + def test_render_literal_bool_true(self, literal_round_trip_spanner): """ SPANNER OVERRIDE: @@ -171,9 +174,9 @@ def test_render_literal_bool_true(self, literal_round_trip): following insertions will fail with `Row [] already exists". Overriding the test to avoid the same failure. """ - literal_round_trip(Boolean(), [True], [True]) + literal_round_trip_spanner(Boolean(), [True], [True]) - def test_render_literal_bool_false(self, literal_round_trip): + def test_render_literal_bool_false(self, literal_round_trip_spanner): """ SPANNER OVERRIDE: @@ -182,7 +185,7 @@ def test_render_literal_bool_false(self, literal_round_trip): following insertions will fail with `Row [] already exists". Overriding the test to avoid the same failure. """ - literal_round_trip(Boolean(), [False], [False]) + literal_round_trip_spanner(Boolean(), [False], [False]) @pytest.mark.skip("Not supported by Cloud Spanner") def test_whereclause(self): @@ -2003,6 +2006,9 @@ def test_huge_int_auto_accommodation(self, connection, intvalue): intvalue, ) + def test_literal(self, literal_round_trip_spanner): + literal_round_trip_spanner(Integer, [5], [5]) + class _UnicodeFixture(__UnicodeFixture): @classmethod @@ -2189,6 +2195,19 @@ def test_dont_truncate_rightside( args[1], ) + def test_literal(self, literal_round_trip_spanner): + # note that in Python 3, this invokes the Unicode + # datatype for the literal part because all strings are unicode + literal_round_trip_spanner(String(40), ["some text"], ["some text"]) + + def test_literal_quoting(self, literal_round_trip_spanner): + data = """some 'text' hey "hi there" that's text""" + literal_round_trip_spanner(String(40), [data], [data]) + + def test_literal_backslashes(self, literal_round_trip_spanner): + data = r"backslash one \ backslash two \\ end" + literal_round_trip_spanner(String(40), [data], [data]) + class TextTest(_TextTest): @classmethod @@ -2224,6 +2243,21 @@ def test_text_empty_strings(self, connection): def test_text_null_strings(self, connection): pass + def test_literal(self, literal_round_trip_spanner): + literal_round_trip_spanner(Text, ["some text"], ["some text"]) + + def test_literal_quoting(self, literal_round_trip_spanner): + data = """some 'text' hey "hi there" that's text""" + literal_round_trip_spanner(Text, [data], [data]) + + def test_literal_backslashes(self, literal_round_trip_spanner): + data = r"backslash one \ backslash two \\ end" + literal_round_trip_spanner(Text, [data], [data]) + + def test_literal_percentsigns(self, literal_round_trip_spanner): + data = r"percent % signs %% percent" + literal_round_trip_spanner(Text, [data], [data]) + class NumericTest(_NumericTest): @testing.fixture @@ -2254,7 +2288,7 @@ def run(type_, input_, output, filter_=None, check_scale=False): return run @emits_warning(r".*does \*not\* support Decimal objects natively") - def test_render_literal_numeric(self, literal_round_trip): + def test_render_literal_numeric(self, literal_round_trip_spanner): """ SPANNER OVERRIDE: @@ -2263,14 +2297,14 @@ def test_render_literal_numeric(self, literal_round_trip): following insertions will fail with `Row [] already exists". Overriding the test to avoid the same failure. """ - literal_round_trip( + literal_round_trip_spanner( Numeric(precision=8, scale=4), [decimal.Decimal("15.7563")], [decimal.Decimal("15.7563")], ) @emits_warning(r".*does \*not\* support Decimal objects natively") - def test_render_literal_numeric_asfloat(self, literal_round_trip): + def test_render_literal_numeric_asfloat(self, literal_round_trip_spanner): """ SPANNER OVERRIDE: @@ -2279,13 +2313,13 @@ def test_render_literal_numeric_asfloat(self, literal_round_trip): following insertions will fail with `Row [] already exists". Overriding the test to avoid the same failure. """ - literal_round_trip( + literal_round_trip_spanner( Numeric(precision=8, scale=4, asdecimal=False), [decimal.Decimal("15.7563")], [15.7563], ) - def test_render_literal_float(self, literal_round_trip): + def test_render_literal_float(self, literal_round_trip_spanner): """ SPANNER OVERRIDE: @@ -2294,7 +2328,7 @@ def test_render_literal_float(self, literal_round_trip): following insertions will fail with `Row [] already exists". Overriding the test to avoid the same failure. """ - literal_round_trip( + literal_round_trip_spanner( Float(4), [decimal.Decimal("15.7563")], [15.7563],