Skip to content

fix: Fixing test for literals due to change in sqlalchemy core tests #384

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/sync-repo-settings.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ branchProtectionRules:
- 'unit'
- 'compliance_tests_13'
- 'compliance_tests_14'
- 'compliance_tests_20'
- 'migration_tests'
- 'cla/google'
- 'Kokoro'
Expand Down
7 changes: 7 additions & 0 deletions google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
format_type,
)
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.sql import elements
from sqlalchemy import ForeignKeyConstraint, types
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.default import DefaultDialect, DefaultExecutionContext
Expand Down Expand Up @@ -314,6 +315,12 @@ def render_literal_value(self, value, type_):
in string. Override the method to add additional escape before using it to
generate a SQL statement.
"""
if value is None and not type_.should_evaluate_none:
# issue #10535 - handle NULL in the compiler without placing
# this onto each type, except for "evaluate None" types
# (e.g. JSON)
return self.process(elements.Null._instance())

raw = ["\\", "'", '"', "\n", "\t", "\r"]
if isinstance(value, str) and any(single in value for single in raw):
value = 'r"""{}"""'.format(value)
Expand Down
12 changes: 6 additions & 6 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def compliance_test_13(session):
)

session.install("mock")
session.install("-e", ".[tracing]")
session.install(".[tracing]")
session.run("pip", "install", "sqlalchemy>=1.1.13,<=1.3.24", "--force-reinstall")
session.run("pip", "install", "opentelemetry-api<=1.10", "--force-reinstall")
session.run("pip", "install", "opentelemetry-sdk<=1.10", "--force-reinstall")
Expand Down Expand Up @@ -191,7 +191,7 @@ def compliance_test_14(session):
)

session.install("mock")
session.install("-e", ".[tracing]")
session.install(".[tracing]")
session.run("pip", "install", "sqlalchemy>=1.4,<2.0", "--force-reinstall")
session.run("python", "create_test_database.py")
session.run(
Expand Down Expand Up @@ -231,7 +231,7 @@ def compliance_test_20(session):
)

session.install("mock")
session.install("-e", ".[tracing]")
session.install(".[tracing]")
session.run("pip", "install", "opentelemetry-api<=1.10", "--force-reinstall")
session.run("python", "create_test_database.py")

Expand All @@ -257,7 +257,7 @@ def unit(session):
# Run SQLAlchemy dialect compliance test suite with OpenTelemetry.
session.install("pytest")
session.install("mock")
session.install("-e", ".")
session.install(".")
session.install("opentelemetry-api==1.1.0")
session.install("opentelemetry-sdk==1.1.0")
session.install("opentelemetry-instrumentation==0.20b0")
Expand Down Expand Up @@ -292,7 +292,7 @@ def _migration_test(session):
session.run("pip", "install", "sqlalchemy>=1.3.11,<2.0", "--force-reinstall")

session.install("pytest")
session.install("-e", ".")
session.install(".")
session.install("alembic")

session.run("python", "create_test_database.py")
Expand Down Expand Up @@ -360,7 +360,7 @@ def snippets(session):
session.install(
"git+https://github.com/googleapis/python-spanner.git#egg=google-cloud-spanner"
)
session.install("-e", ".")
session.install(".")
session.run("python", "create_test_database.py")
session.run(
"py.test",
Expand Down
63 changes: 63 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,72 @@

import pytest
from sqlalchemy.dialects import registry
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
from sqlalchemy.sql.elements import literal

registry.register("spanner", "google.cloud.sqlalchemy_spanner", "SpannerDialect")

pytest.register_assert_rewrite("sqlalchemy.testing.assertions")

from sqlalchemy.testing.plugin.pytestplugin import * # noqa: E402, F401, F403


@pytest.fixture
def literal_round_trip_spanner(metadata, connection):
# for literal, we test the literal render in an INSERT
# into a typed column. we can then SELECT it back as its
# official type;

def run(
type_,
input_,
output,
filter_=None,
compare=None,
support_whereclause=True,
):
t = Table("t", metadata, Column("x", type_))
t.create(connection)

for value in input_:
ins = t.insert().values(x=literal(value, type_, literal_execute=True))
connection.execute(ins)

if support_whereclause:
if compare:
stmt = t.select().where(
t.c.x
== literal(
compare,
type_,
literal_execute=True,
),
t.c.x
== literal(
input_[0],
type_,
literal_execute=True,
),
)
else:
stmt = t.select().where(
t.c.x
== literal(
compare if compare is not None else input_[0],
type_,
literal_execute=True,
)
)
else:
stmt = t.select()

rows = connection.execute(stmt).all()
assert rows, "No rows returned"
for row in rows:
value = row[0]
if filter_ is not None:
value = filter_(value)
assert value in output

return run
56 changes: 45 additions & 11 deletions test/test_suite_20.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,10 @@
UnicodeTextTest as _UnicodeTextTest,
_UnicodeFixture as __UnicodeFixture,
) # noqa: F401, F403
from test._helpers import get_db_url, get_project
from test._helpers import (
get_db_url,
get_project,
)

config.test_schema = ""

Expand All @@ -162,7 +165,7 @@ class BooleanTest(_BooleanTest):
def test_render_literal_bool(self):
pass

def test_render_literal_bool_true(self, literal_round_trip):
def test_render_literal_bool_true(self, literal_round_trip_spanner):
"""
SPANNER OVERRIDE:

Expand All @@ -171,9 +174,9 @@ def test_render_literal_bool_true(self, literal_round_trip):
following insertions will fail with `Row [] already exists".
Overriding the test to avoid the same failure.
"""
literal_round_trip(Boolean(), [True], [True])
literal_round_trip_spanner(Boolean(), [True], [True])

def test_render_literal_bool_false(self, literal_round_trip):
def test_render_literal_bool_false(self, literal_round_trip_spanner):
"""
SPANNER OVERRIDE:

Expand All @@ -182,7 +185,7 @@ def test_render_literal_bool_false(self, literal_round_trip):
following insertions will fail with `Row [] already exists".
Overriding the test to avoid the same failure.
"""
literal_round_trip(Boolean(), [False], [False])
literal_round_trip_spanner(Boolean(), [False], [False])

@pytest.mark.skip("Not supported by Cloud Spanner")
def test_whereclause(self):
Expand Down Expand Up @@ -2003,6 +2006,9 @@ def test_huge_int_auto_accommodation(self, connection, intvalue):
intvalue,
)

def test_literal(self, literal_round_trip_spanner):
literal_round_trip_spanner(Integer, [5], [5])


class _UnicodeFixture(__UnicodeFixture):
@classmethod
Expand Down Expand Up @@ -2189,6 +2195,19 @@ def test_dont_truncate_rightside(
args[1],
)

def test_literal(self, literal_round_trip_spanner):
# note that in Python 3, this invokes the Unicode
# datatype for the literal part because all strings are unicode
literal_round_trip_spanner(String(40), ["some text"], ["some text"])

def test_literal_quoting(self, literal_round_trip_spanner):
data = """some 'text' hey "hi there" that's text"""
literal_round_trip_spanner(String(40), [data], [data])

def test_literal_backslashes(self, literal_round_trip_spanner):
data = r"backslash one \ backslash two \\ end"
literal_round_trip_spanner(String(40), [data], [data])


class TextTest(_TextTest):
@classmethod
Expand Down Expand Up @@ -2224,6 +2243,21 @@ def test_text_empty_strings(self, connection):
def test_text_null_strings(self, connection):
pass

def test_literal(self, literal_round_trip_spanner):
literal_round_trip_spanner(Text, ["some text"], ["some text"])

def test_literal_quoting(self, literal_round_trip_spanner):
data = """some 'text' hey "hi there" that's text"""
literal_round_trip_spanner(Text, [data], [data])

def test_literal_backslashes(self, literal_round_trip_spanner):
data = r"backslash one \ backslash two \\ end"
literal_round_trip_spanner(Text, [data], [data])

def test_literal_percentsigns(self, literal_round_trip_spanner):
data = r"percent % signs %% percent"
literal_round_trip_spanner(Text, [data], [data])


class NumericTest(_NumericTest):
@testing.fixture
Expand Down Expand Up @@ -2254,7 +2288,7 @@ def run(type_, input_, output, filter_=None, check_scale=False):
return run

@emits_warning(r".*does \*not\* support Decimal objects natively")
def test_render_literal_numeric(self, literal_round_trip):
def test_render_literal_numeric(self, literal_round_trip_spanner):
"""
SPANNER OVERRIDE:

Expand All @@ -2263,14 +2297,14 @@ def test_render_literal_numeric(self, literal_round_trip):
following insertions will fail with `Row [] already exists".
Overriding the test to avoid the same failure.
"""
literal_round_trip(
literal_round_trip_spanner(
Numeric(precision=8, scale=4),
[decimal.Decimal("15.7563")],
[decimal.Decimal("15.7563")],
)

@emits_warning(r".*does \*not\* support Decimal objects natively")
def test_render_literal_numeric_asfloat(self, literal_round_trip):
def test_render_literal_numeric_asfloat(self, literal_round_trip_spanner):
"""
SPANNER OVERRIDE:

Expand All @@ -2279,13 +2313,13 @@ def test_render_literal_numeric_asfloat(self, literal_round_trip):
following insertions will fail with `Row [] already exists".
Overriding the test to avoid the same failure.
"""
literal_round_trip(
literal_round_trip_spanner(
Numeric(precision=8, scale=4, asdecimal=False),
[decimal.Decimal("15.7563")],
[15.7563],
)

def test_render_literal_float(self, literal_round_trip):
def test_render_literal_float(self, literal_round_trip_spanner):
"""
SPANNER OVERRIDE:

Expand All @@ -2294,7 +2328,7 @@ def test_render_literal_float(self, literal_round_trip):
following insertions will fail with `Row [] already exists".
Overriding the test to avoid the same failure.
"""
literal_round_trip(
literal_round_trip_spanner(
Float(4),
[decimal.Decimal("15.7563")],
[15.7563],
Expand Down