Skip to content

Commit

Permalink
Parametrize SELECT queries (#1777)
Browse files Browse the repository at this point in the history
* Parametrize SELECT queries

* Make sure _execute() uses the same query as returned from sql()

* Parametrize .values, .values_list, .exists and .count queries

* Fix Postgres issues

* Add params_inline arg to QuerySet.sql()

* Use pypika-tortoise 0.3.0
  • Loading branch information
henadzit authored Nov 21, 2024
1 parent 916d6cb commit 7f077c1
Show file tree
Hide file tree
Showing 23 changed files with 492 additions and 231 deletions.
14 changes: 7 additions & 7 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ classifiers = [

[tool.poetry.dependencies]
python = "^3.8"
pypika-tortoise = "^0.2.2"
pypika-tortoise = "^0.3.0"
iso8601 = "^2.1.0"
aiosqlite = ">=0.16.0, <0.21.0"
pytz = "*"
Expand Down
2 changes: 1 addition & 1 deletion tests/contrib/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async def test_mysql_func_rand(self):
@test.requireCapability(dialect="mysql")
async def test_mysql_func_rand_with_seed(self):
sql = IntFields.all().annotate(randnum=Rand(0)).values("intnum", "randnum").sql()
expected_sql = "SELECT `intnum` `intnum`,RAND(0) `randnum` FROM `intfields`"
expected_sql = "SELECT `intnum` `intnum`,RAND(%s) `randnum` FROM `intfields`"
self.assertEqual(sql, expected_sql)

@test.requireCapability(dialect="postgres")
Expand Down
73 changes: 59 additions & 14 deletions tests/test_case_when.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@ async def asyncSetUp(self):

async def test_single_when(self):
category = Case(When(intnum__gte=8, then="big"), default="default")
sql = IntFields.all().annotate(category=category).values("intnum", "category").sql()
sql = (
IntFields.all()
.annotate(category=category)
.values("intnum", "category")
.sql(params_inline=True)
)

dialect = self.db.schema_generator.DIALECT
if dialect == "mysql":
Expand All @@ -27,7 +32,12 @@ async def test_multi_when(self):
category = Case(
When(intnum__gte=8, then="big"), When(intnum__lte=2, then="small"), default="default"
)
sql = IntFields.all().annotate(category=category).values("intnum", "category").sql()
sql = (
IntFields.all()
.annotate(category=category)
.values("intnum", "category")
.sql(params_inline=True)
)

dialect = self.db.schema_generator.DIALECT
if dialect == "mysql":
Expand All @@ -38,7 +48,12 @@ async def test_multi_when(self):

async def test_q_object_when(self):
category = Case(When(Q(intnum__gt=2, intnum__lt=8), then="middle"), default="default")
sql = IntFields.all().annotate(category=category).values("intnum", "category").sql()
sql = (
IntFields.all()
.annotate(category=category)
.values("intnum", "category")
.sql(params_inline=True)
)

dialect = self.db.schema_generator.DIALECT
if dialect == "mysql":
Expand All @@ -49,7 +64,12 @@ async def test_q_object_when(self):

async def test_F_then(self):
category = Case(When(intnum__gte=8, then=F("intnum_null")), default="default")
sql = IntFields.all().annotate(category=category).values("intnum", "category").sql()
sql = (
IntFields.all()
.annotate(category=category)
.values("intnum", "category")
.sql(params_inline=True)
)

dialect = self.db.schema_generator.DIALECT
if dialect == "mysql":
Expand All @@ -61,7 +81,12 @@ async def test_F_then(self):
async def test_AE_then(self):
# AE: ArithmeticExpression
category = Case(When(intnum__gte=8, then=F("intnum") + 1), default="default")
sql = IntFields.all().annotate(category=category).values("intnum", "category").sql()
sql = (
IntFields.all()
.annotate(category=category)
.values("intnum", "category")
.sql(params_inline=True)
)

dialect = self.db.schema_generator.DIALECT
if dialect == "mysql":
Expand All @@ -72,7 +97,12 @@ async def test_AE_then(self):

async def test_func_then(self):
category = Case(When(intnum__gte=8, then=Coalesce("intnum_null", 10)), default="default")
sql = IntFields.all().annotate(category=category).values("intnum", "category").sql()
sql = (
IntFields.all()
.annotate(category=category)
.values("intnum", "category")
.sql(params_inline=True)
)

dialect = self.db.schema_generator.DIALECT
if dialect == "mysql":
Expand All @@ -83,7 +113,12 @@ async def test_func_then(self):

async def test_F_default(self):
category = Case(When(intnum__gte=8, then="big"), default=F("intnum_null"))
sql = IntFields.all().annotate(category=category).values("intnum", "category").sql()
sql = (
IntFields.all()
.annotate(category=category)
.values("intnum", "category")
.sql(params_inline=True)
)

dialect = self.db.schema_generator.DIALECT
if dialect == "mysql":
Expand All @@ -95,7 +130,12 @@ async def test_F_default(self):
async def test_AE_default(self):
# AE: ArithmeticExpression
category = Case(When(intnum__gte=8, then=8), default=F("intnum") + 1)
sql = IntFields.all().annotate(category=category).values("intnum", "category").sql()
sql = (
IntFields.all()
.annotate(category=category)
.values("intnum", "category")
.sql(params_inline=True)
)

dialect = self.db.schema_generator.DIALECT
if dialect == "mysql":
Expand All @@ -106,7 +146,12 @@ async def test_AE_default(self):

async def test_func_default(self):
category = Case(When(intnum__gte=8, then=8), default=Coalesce("intnum_null", 10))
sql = IntFields.all().annotate(category=category).values("intnum", "category").sql()
sql = (
IntFields.all()
.annotate(category=category)
.values("intnum", "category")
.sql(params_inline=True)
)

dialect = self.db.schema_generator.DIALECT
if dialect == "mysql":
Expand All @@ -124,7 +169,7 @@ async def test_case_when_in_where(self):
.annotate(category=category)
.filter(category__in=["big", "small"])
.values("intnum")
.sql()
.sql(params_inline=True)
)
dialect = self.db.schema_generator.DIALECT
if dialect == "mysql":
Expand All @@ -139,7 +184,7 @@ async def test_annotation_in_when_annotation(self):
.annotate(intnum_plus_1=F("intnum") + 1)
.annotate(bigger_than_10=Case(When(Q(intnum_plus_1__gte=10), then=True), default=False))
.values("id", "intnum", "intnum_plus_1", "bigger_than_10")
.sql()
.sql(params_inline=True)
)

dialect = self.db.schema_generator.DIALECT
Expand All @@ -155,7 +200,7 @@ async def test_func_annotation_in_when_annotation(self):
.annotate(intnum_col=Coalesce("intnum", 0))
.annotate(is_zero=Case(When(Q(intnum_col=0), then=True), default=False))
.values("id", "intnum_col", "is_zero")
.sql()
.sql(params_inline=True)
)

dialect = self.db.schema_generator.DIALECT
Expand All @@ -172,7 +217,7 @@ async def test_case_when_in_group_by(self):
.annotate(count=Count("id"))
.group_by("is_zero")
.values("is_zero", "count")
.sql()
.sql(params_inline=True)
)

dialect = self.db.schema_generator.DIALECT
Expand All @@ -188,4 +233,4 @@ async def test_unknown_field_in_when_annotation(self):
with self.assertRaisesRegex(FieldError, "Unknown filter param 'unknown'.+"):
IntFields.all().annotate(intnum_col=Coalesce("intnum", 0)).annotate(
is_zero=Case(When(Q(unknown=0), then="1"), default="2")
).sql()
).sql(params_inline=True)
8 changes: 8 additions & 0 deletions tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,14 @@ async def test_between_and(self):
[Decimal("1.2345")],
)

async def test_in(self):
self.assertEqual(
await DecimalFields.filter(
decimal__in=[Decimal("1.2345"), Decimal("1000")]
).values_list("decimal", flat=True),
[Decimal("1.2345")],
)


class TestCharFkFieldFilters(test.TestCase):
async def asyncSetUp(self):
Expand Down
15 changes: 15 additions & 0 deletions tests/test_fuzz.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from tests.testmodels import CharFields
from tortoise.contrib import test
from tortoise.contrib.test.condition import NotEQ
from tortoise.functions import Upper

DODGY_STRINGS = [
"a/",
Expand All @@ -9,6 +10,11 @@
"a\\x39",
"a'",
'"',
'""',
"'",
"''",
"\\_",
"\\\\_",
"‘a",
"a’",
"‘a’",
Expand Down Expand Up @@ -134,3 +140,12 @@ async def test_char_fuzz(self):
)
self.assertEqual(obj1.pk, obj5.pk)
self.assertEqual(char, obj5.char)

# Filter by a function
obj6 = (
await CharFields.annotate(upper_char=Upper("char"))
.filter(id=obj1.pk, upper_char=Upper("char"))
.first()
)
self.assertEqual(obj1.pk, obj6.pk)
self.assertEqual(char, obj6.char)
6 changes: 3 additions & 3 deletions tests/test_model_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,14 +296,14 @@ async def test_index_access(self):

async def test_index_badval(self):
with self.assertRaises(ObjectDoesNotExistError) as cm:
await self.cls[100000]
await self.cls[32767]
the_exception = cm.exception
# For compatibility reasons this should be an instance of KeyError
self.assertIsInstance(the_exception, KeyError)
self.assertIs(the_exception.model, self.cls)
self.assertEqual(the_exception.pk_name, "id")
self.assertEqual(the_exception.pk_val, 100000)
self.assertEqual(str(the_exception), f"{self.cls.__name__} has no object with id=100000")
self.assertEqual(the_exception.pk_val, 32767)
self.assertEqual(str(the_exception), f"{self.cls.__name__} has no object with id=32767")

async def test_index_badtype(self):
with self.assertRaises(ObjectDoesNotExistError) as cm:
Expand Down
Loading

0 comments on commit 7f077c1

Please sign in to comment.