Skip to content

Commit

Permalink
Translate to NULLIF
Browse files Browse the repository at this point in the history
  • Loading branch information
roji committed Dec 13, 2024
1 parent c099cef commit 70b2aa5
Show file tree
Hide file tree
Showing 6 changed files with 251 additions and 0 deletions.
2 changes: 2 additions & 0 deletions EFCore.sln.DotSettings
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,9 @@ The .NET Foundation licenses this file to you under the MIT license.
<s:Boolean x:Key="/Default/UserDictionary/Words/=subquery/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=subquery_0027s/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=transactionality/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=uncoalescing/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=unconfigured/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=unequality/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=unignore/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=fixup/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=attacher/@EntryIndexedValue">True</s:Boolean>
Expand Down
47 changes: 47 additions & 0 deletions src/EFCore.Relational/Query/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,31 @@ public virtual SqlExpression Case(
elseResult = lastCase.ElseResult;
}

// Optimize:
// a == b ? null : a -> NULLIF(a, b)
// a != b ? a : null -> NULLIF(a, b)
if (operand is null
&& typeMappedWhenClauses is
[
{
Test: SqlBinaryExpression { OperatorType: ExpressionType.Equal or ExpressionType.NotEqual } binary,
Result: var result
}
])
{
switch (binary.OperatorType)
{
case ExpressionType.Equal
when result is SqlConstantExpression { Value: null }
&& elseResult is not null
&& TryTranslateToNullIf(elseResult, out var nullIfTranslation):
case ExpressionType.NotEqual
when elseResult is null or SqlConstantExpression { Value: null }
&& TryTranslateToNullIf(result, out nullIfTranslation):
return nullIfTranslation;
}
}

return existingExpression is CaseExpression expr
&& operand == expr.Operand
&& typeMappedWhenClauses.SequenceEqual(expr.WhenClauses)
Expand All @@ -837,6 +862,28 @@ bool IsSkipped(CaseWhenClause clause)

bool IsMatched(CaseWhenClause clause)
=> operand is null && clause.Test is SqlConstantExpression { Value: true };

bool TryTranslateToNullIf(SqlExpression conditionalResult, [NotNullWhen(true)] out SqlExpression? nullIfTranslation)
{
var (left, right) = (binary.Left, binary.Right);

if (left.Equals(conditionalResult))
{
nullIfTranslation = Function(
"NULLIF", [left, right], true, [false, false], left.Type, left.TypeMapping);
return true;
}

if (right.Equals(conditionalResult))
{
nullIfTranslation = Function(
"NULLIF", [right, left], true, [false, false], right.Type, right.TypeMapping);
return true;
}

nullIfTranslation = null;
return false;
}
}

/// <inheritdoc />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,66 @@ public override async Task TimeSpan_Compare_to_simple_zero(bool async, bool comp

#endregion Compare

#region Uncoalescing conditional / NullIf

public override Task Uncoalescing_conditional_with_equality_left(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Uncoalescing_conditional_with_equality_left(a);

AssertSql(
"""
SELECT VALUE c
FROM root c
WHERE (((c["Int"] = 9) ? null : c["Int"]) > 1)
""");
});

public override Task Uncoalescing_conditional_with_equality_right(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Uncoalescing_conditional_with_equality_right(a);

AssertSql(
"""
SELECT VALUE c
FROM root c
WHERE (((9 = c["Int"]) ? null : c["Int"]) > 1)
""");
});

public override Task Uncoalescing_conditional_with_unequality_left(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Uncoalescing_conditional_with_unequality_left(a);

AssertSql(
"""
SELECT VALUE c
FROM root c
WHERE (((c["Int"] != 9) ? c["Int"] : null) > 1)
""");
});

public override Task Uncoalescing_conditional_with_inequality_right(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Uncoalescing_conditional_with_inequality_right(a);

AssertSql(
"""
SELECT VALUE c
FROM root c
WHERE (((9 != c["Int"]) ? c["Int"] : null) > 1)
""");
});

#endregion Uncoalescing conditional / NullIf

[ConditionalFact]
public virtual void Check_all_tests_overridden()
=> TestHelpers.AssertAllMethodsOverridden(GetType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -429,4 +429,38 @@ await AssertQuery(
}

#endregion

#region Uncoalescing conditional

// In relational providers, x == a ? null : x is translated to SQL NULLIF

[Theory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Uncoalescing_conditional_with_equality_left(bool async)
=> AssertQuery(
async,
cs => cs.Set<BasicTypesEntity>().Where(x => (x.Int == 9 ? null : x.Int) > 1));

[Theory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Uncoalescing_conditional_with_equality_right(bool async)
=> AssertQuery(
async,
cs => cs.Set<BasicTypesEntity>().Where(x => (9 == x.Int ? null : x.Int) > 1));

[Theory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Uncoalescing_conditional_with_unequality_left(bool async)
=> AssertQuery(
async,
cs => cs.Set<BasicTypesEntity>().Where(x => (x.Int != 9 ? x.Int : null) > 1));

[Theory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Uncoalescing_conditional_with_inequality_right(bool async)
=> AssertQuery(
async,
cs => cs.Set<BasicTypesEntity>().Where(x => (9 != x.Int ? x.Int : null) > 1));

#endregion Uncoalescing conditional
}
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,62 @@ FROM [BasicTypesEntities] AS [b]

#endregion Compare

#region Uncoalescing conditional / NullIf

public override async Task Uncoalescing_conditional_with_equality_left(bool async)
{
await base.Uncoalescing_conditional_with_equality_left(async);

AssertSql(
"""
SELECT [b].[Id], [b].[Bool], [b].[Byte], [b].[ByteArray], [b].[DateOnly], [b].[DateTime], [b].[DateTimeOffset], [b].[Decimal], [b].[Double], [b].[Enum], [b].[FlagsEnum], [b].[Float], [b].[Guid], [b].[Int], [b].[Long], [b].[Short], [b].[String], [b].[TimeOnly], [b].[TimeSpan]
FROM [BasicTypesEntities] AS [b]
WHERE NULLIF([b].[Int], 9) > 1
""");
}

public override async Task Uncoalescing_conditional_with_equality_right(bool async)
{
await base.Uncoalescing_conditional_with_equality_right(async);

AssertSql(
"""
SELECT [b].[Id], [b].[Bool], [b].[Byte], [b].[ByteArray], [b].[DateOnly], [b].[DateTime], [b].[DateTimeOffset], [b].[Decimal], [b].[Double], [b].[Enum], [b].[FlagsEnum], [b].[Float], [b].[Guid], [b].[Int], [b].[Long], [b].[Short], [b].[String], [b].[TimeOnly], [b].[TimeSpan]
FROM [BasicTypesEntities] AS [b]
WHERE NULLIF([b].[Int], 9) > 1
""");
}

public override async Task Uncoalescing_conditional_with_unequality_left(bool async)
{
await base.Uncoalescing_conditional_with_unequality_left(async);

AssertSql(
"""
SELECT [b].[Id], [b].[Bool], [b].[Byte], [b].[ByteArray], [b].[DateOnly], [b].[DateTime], [b].[DateTimeOffset], [b].[Decimal], [b].[Double], [b].[Enum], [b].[FlagsEnum], [b].[Float], [b].[Guid], [b].[Int], [b].[Long], [b].[Short], [b].[String], [b].[TimeOnly], [b].[TimeSpan]
FROM [BasicTypesEntities] AS [b]
WHERE CASE
WHEN [b].[Int] <> 9 THEN [b].[Int]
END > 1
""");
}

public override async Task Uncoalescing_conditional_with_inequality_right(bool async)
{
await base.Uncoalescing_conditional_with_inequality_right(async);

AssertSql(
"""
SELECT [b].[Id], [b].[Bool], [b].[Byte], [b].[ByteArray], [b].[DateOnly], [b].[DateTime], [b].[DateTimeOffset], [b].[Decimal], [b].[Double], [b].[Enum], [b].[FlagsEnum], [b].[Float], [b].[Guid], [b].[Int], [b].[Long], [b].[Short], [b].[String], [b].[TimeOnly], [b].[TimeSpan]
FROM [BasicTypesEntities] AS [b]
WHERE CASE
WHEN 9 <> [b].[Int] THEN [b].[Int]
END > 1
""");
}

#endregion Uncoalescing conditional / NullIf

[ConditionalFact]
public virtual void Check_all_tests_overridden()
=> TestHelpers.AssertAllMethodsOverridden(GetType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,58 @@ public override async Task TimeSpan_Compare_to_simple_zero(bool async, bool comp

#endregion Compare

#region Uncoalescing conditional / NullIf

public override async Task Uncoalescing_conditional_with_equality_left(bool async)
{
await base.Uncoalescing_conditional_with_equality_left(async);

AssertSql(
"""
SELECT "b"."Id", "b"."Bool", "b"."Byte", "b"."ByteArray", "b"."DateOnly", "b"."DateTime", "b"."DateTimeOffset", "b"."Decimal", "b"."Double", "b"."Enum", "b"."FlagsEnum", "b"."Float", "b"."Guid", "b"."Int", "b"."Long", "b"."Short", "b"."String", "b"."TimeOnly", "b"."TimeSpan"
FROM "BasicTypesEntities" AS "b"
WHERE NULLIF("b"."Int", 9) > 1
""");
}

public override async Task Uncoalescing_conditional_with_equality_right(bool async)
{
await base.Uncoalescing_conditional_with_equality_right(async);

AssertSql(
"""
SELECT "b"."Id", "b"."Bool", "b"."Byte", "b"."ByteArray", "b"."DateOnly", "b"."DateTime", "b"."DateTimeOffset", "b"."Decimal", "b"."Double", "b"."Enum", "b"."FlagsEnum", "b"."Float", "b"."Guid", "b"."Int", "b"."Long", "b"."Short", "b"."String", "b"."TimeOnly", "b"."TimeSpan"
FROM "BasicTypesEntities" AS "b"
WHERE NULLIF("b"."Int", 9) > 1
""");
}

public override async Task Uncoalescing_conditional_with_unequality_left(bool async)
{
await base.Uncoalescing_conditional_with_unequality_left(async);

AssertSql(
"""
SELECT "b"."Id", "b"."Bool", "b"."Byte", "b"."ByteArray", "b"."DateOnly", "b"."DateTime", "b"."DateTimeOffset", "b"."Decimal", "b"."Double", "b"."Enum", "b"."FlagsEnum", "b"."Float", "b"."Guid", "b"."Int", "b"."Long", "b"."Short", "b"."String", "b"."TimeOnly", "b"."TimeSpan"
FROM "BasicTypesEntities" AS "b"
WHERE NULLIF("b"."Int", 9) > 1
""");
}

public override async Task Uncoalescing_conditional_with_inequality_right(bool async)
{
await base.Uncoalescing_conditional_with_inequality_right(async);

AssertSql(
"""
SELECT "b"."Id", "b"."Bool", "b"."Byte", "b"."ByteArray", "b"."DateOnly", "b"."DateTime", "b"."DateTimeOffset", "b"."Decimal", "b"."Double", "b"."Enum", "b"."FlagsEnum", "b"."Float", "b"."Guid", "b"."Int", "b"."Long", "b"."Short", "b"."String", "b"."TimeOnly", "b"."TimeSpan"
FROM "BasicTypesEntities" AS "b"
WHERE NULLIF("b"."Int", 9) > 1
""");
}

#endregion Uncoalescing conditional / NullIf

[ConditionalFact]
public virtual void Check_all_tests_overridden()
=> TestHelpers.AssertAllMethodsOverridden(GetType());
Expand Down

0 comments on commit 70b2aa5

Please sign in to comment.