Skip to content

Commit

Permalink
Fix optimization of CASE op WHEN (#33869)
Browse files Browse the repository at this point in the history
The CASE op WHEN ... expression was incorrectly optimized as if it were a CASE WHEN expression if the test expressions contained a TRUE value.

This also makes it possible to use the CASE op WHEN expression to avoid duplicating some subexpressions in the translation of the bool to string conversion.

The new translation avoids duplicating sub-expressions and introducing negations.

Fixes #33867
  • Loading branch information
ranma42 authored Jun 3, 2024
1 parent 32f11fc commit ed6213d
Show file tree
Hide file tree
Showing 13 changed files with 166 additions and 48 deletions.
30 changes: 19 additions & 11 deletions src/EFCore.Relational/Query/SqlNullabilityProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -498,9 +498,7 @@ protected virtual SqlExpression VisitAtTimeZone(
/// <returns>An optimized sql expression.</returns>
protected virtual SqlExpression VisitCase(CaseExpression caseExpression, bool allowOptimizedExpansion, out bool nullable)
{
// if there is no 'else' there is a possibility of null, when none of the conditions are met
// otherwise the result is nullable if any of the WhenClause results OR ElseResult is nullable
nullable = caseExpression.ElseResult == null;
nullable = false;
var currentNonNullableColumnsCount = _nonNullableColumns.Count;
var currentNullValueColumnsCount = _nullValueColumns.Count;

Expand All @@ -515,11 +513,16 @@ protected virtual SqlExpression VisitCase(CaseExpression caseExpression, bool al
var test = Visit(
whenClause.Test, allowOptimizedExpansion: testIsCondition, preserveColumnNullabilityInformation: true, out _);

if (IsTrue(test))
var testCondition = testIsCondition
? test
: Visit(_sqlExpressionFactory.Equal(operand!, test),
allowOptimizedExpansion: testIsCondition, preserveColumnNullabilityInformation: true, out _);

if (IsTrue(testCondition))
{
testEvaluatesToTrue = true;
}
else if (IsFalse(test))
else if (IsFalse(testCondition))
{
// if test evaluates to 'false' we can remove the WhenClause
RestoreNonNullableColumnsList(currentNonNullableColumnsCount);
Expand All @@ -538,6 +541,12 @@ protected virtual SqlExpression VisitCase(CaseExpression caseExpression, bool al
// if test evaluates to 'true' we can remove every condition that comes after, including ElseResult
if (testEvaluatesToTrue)
{
// if the first When clause is always satisfied, simply return its result
if (whenClauses.Count == 1)
{
return whenClauses[0].Result;
}

break;
}
}
Expand All @@ -547,6 +556,10 @@ protected virtual SqlExpression VisitCase(CaseExpression caseExpression, bool al
{
elseResult = Visit(caseExpression.ElseResult, out var elseResultNullable);
nullable |= elseResultNullable;

// if there is no 'else' there is a possibility of null, when none of the conditions are met
// otherwise the result is nullable if any of the WhenClause results OR ElseResult is nullable
nullable |= elseResult == null;
}

RestoreNonNullableColumnsList(currentNonNullableColumnsCount);
Expand All @@ -560,12 +573,7 @@ protected virtual SqlExpression VisitCase(CaseExpression caseExpression, bool al
return elseResult ?? _sqlExpressionFactory.Constant(null, caseExpression.Type, caseExpression.TypeMapping);
}

// if there is only one When clause and it's test evaluates to 'true' AND there is no else block, simply return the result
return elseResult == null
&& whenClauses.Count == 1
&& IsTrue(whenClauses[0].Test)
? whenClauses[0].Result
: caseExpression.Update(operand, whenClauses, elseResult);
return caseExpression.Update(operand, whenClauses, elseResult);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,14 @@ public SqlServerObjectToStringTranslator(ISqlExpressionFactory sqlExpressionFact
if (instance is ColumnExpression { IsNullable: true })
{
return _sqlExpressionFactory.Case(
instance,
new[]
{
new CaseWhenClause(
_sqlExpressionFactory.Equal(instance, _sqlExpressionFactory.Constant(false)),
_sqlExpressionFactory.Constant(false),
_sqlExpressionFactory.Constant(false.ToString())),
new CaseWhenClause(
_sqlExpressionFactory.Equal(instance, _sqlExpressionFactory.Constant(true)),
_sqlExpressionFactory.Constant(true),
_sqlExpressionFactory.Constant(true.ToString()))
},
_sqlExpressionFactory.Constant(null, typeof(string)));
Expand All @@ -98,10 +99,10 @@ public SqlServerObjectToStringTranslator(ISqlExpressionFactory sqlExpressionFact
new[]
{
new CaseWhenClause(
_sqlExpressionFactory.Equal(instance, _sqlExpressionFactory.Constant(false)),
_sqlExpressionFactory.Constant(false.ToString()))
instance,
_sqlExpressionFactory.Constant(true.ToString()))
},
_sqlExpressionFactory.Constant(true.ToString()));
_sqlExpressionFactory.Constant(false.ToString()));
}

return TypeMapping.TryGetValue(instance.Type, out var storeType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,14 @@ public SqliteObjectToStringTranslator(ISqlExpressionFactory sqlExpressionFactory
if (instance is ColumnExpression { IsNullable: true })
{
return _sqlExpressionFactory.Case(
instance,
new[]
{
new CaseWhenClause(
_sqlExpressionFactory.Equal(instance, _sqlExpressionFactory.Constant(false)),
_sqlExpressionFactory.Constant(false),
_sqlExpressionFactory.Constant(false.ToString())),
new CaseWhenClause(
_sqlExpressionFactory.Equal(instance, _sqlExpressionFactory.Constant(true)),
_sqlExpressionFactory.Constant(true),
_sqlExpressionFactory.Constant(true.ToString()))
},
_sqlExpressionFactory.Constant(null, typeof(string)));
Expand All @@ -93,10 +94,10 @@ public SqliteObjectToStringTranslator(ISqlExpressionFactory sqlExpressionFactory
new[]
{
new CaseWhenClause(
_sqlExpressionFactory.Equal(instance, _sqlExpressionFactory.Constant(false)),
_sqlExpressionFactory.Constant(false.ToString()))
instance,
_sqlExpressionFactory.Constant(true.ToString()))
},
_sqlExpressionFactory.Constant(true.ToString()));
_sqlExpressionFactory.Constant(false.ToString()));
}

return TypeMapping.Contains(instance.Type)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,28 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
new CaseWhenClause(args[4], args[5]),
]))
);

modelBuilder.HasDbFunction(
typeof(NullSemanticsQueryFixtureBase).GetMethod(nameof(BoolSwitch)),
b => b.HasTranslation(args => new CaseExpression(
operand: args[0],
[
new CaseWhenClause(new SqlConstantExpression(true, typeMapping: BoolTypeMapping.Default), args[1]),
new CaseWhenClause(new SqlConstantExpression(false, typeMapping: BoolTypeMapping.Default), args[2]),
]))
);
}

public static int? Cases(bool c1, int v1, bool c2, int v2, bool c3, int v3) =>
c1 ? v1 :
c2 ? v2 :
c3 ? v3 :
null;

public static int BoolSwitch(bool x, int whenTrue, int whenFalse) =>
x switch
{
true => whenTrue,
false => whenFalse,
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -1796,6 +1796,31 @@ public virtual Task CaseWhen_equal_to_first_or_third_select(bool async)
assertOrder: true
);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task CaseOpWhen_projection(bool async)
=> AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>()
.OrderBy(x => x.Id)
.Select(x => NullSemanticsQueryFixtureBase.BoolSwitch(
x.StringA == "Foo", 3, 2
)),
assertOrder: true
);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task CaseOpWhen_predicate(bool async)
=> AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>()
.Where(x => NullSemanticsQueryFixtureBase.BoolSwitch(
x.StringA == "Foo", 3, 2
) == 2),
assertOrder: true
);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task False_compared_to_negated_is_null(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4007,8 +4007,8 @@ public override async Task ToString_boolean_property_non_nullable(bool async)
AssertSql(
"""
SELECT CASE
WHEN [w].[IsAutomatic] = CAST(0 AS bit) THEN N'False'
ELSE N'True'
WHEN [w].[IsAutomatic] = CAST(1 AS bit) THEN N'True'
ELSE N'False'
END
FROM [Weapons] AS [w]
""");
Expand All @@ -4020,9 +4020,9 @@ public override async Task ToString_boolean_property_nullable(bool async)

AssertSql(
"""
SELECT CASE
WHEN [f].[Eradicated] = CAST(0 AS bit) THEN N'False'
WHEN [f].[Eradicated] = CAST(1 AS bit) THEN N'True'
SELECT CASE [f].[Eradicated]
WHEN CAST(0 AS bit) THEN N'False'
WHEN CAST(1 AS bit) THEN N'True'
ELSE NULL
END
FROM [Factions] AS [f]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1218,7 +1218,7 @@ FROM OPENJSON([j].[OwnedCollectionRoot], '$') AS [o]
OUTER APPLY (
SELECT [j].[Id], CAST(JSON_VALUE([o0].[value], '$.Date') AS datetime2) AS [Date], CAST(JSON_VALUE([o0].[value], '$.Enum') AS int) AS [Enum], JSON_QUERY([o0].[value], '$.Enums') AS [Enums], CAST(JSON_VALUE([o0].[value], '$.Fraction') AS decimal(18,2)) AS [Fraction], CAST(JSON_VALUE([o0].[value], '$.NullableEnum') AS int) AS [NullableEnum], JSON_QUERY([o0].[value], '$.NullableEnums') AS [NullableEnums], JSON_QUERY([o0].[value], '$.OwnedCollectionLeaf') AS [c], JSON_QUERY([o0].[value], '$.OwnedReferenceLeaf') AS [c0], [o0].[key], CAST([o0].[key] AS int) AS [c1]
FROM OPENJSON(JSON_QUERY([o].[value], '$.OwnedCollectionBranch'), '$') AS [o0]
WHERE CAST(JSON_VALUE([o0].[value], '$.Date') AS datetime2) <> '2000-01-01T00:00:00.0000000' OR CAST(JSON_VALUE([o0].[value], '$.Date') AS datetime2) IS NULL
WHERE CAST(JSON_VALUE([o0].[value], '$.Date') AS datetime2) <> '2000-01-01T00:00:00.0000000'
) AS [o1]
) AS [s]
ORDER BY [j].[Id], [s].[c1], [s].[key], [s].[c10]
Expand Down Expand Up @@ -1373,7 +1373,7 @@ FROM OPENJSON([j].[OwnedCollectionRoot], '$') AS [o2]
OUTER APPLY (
SELECT [j].[Id], CAST(JSON_VALUE([o3].[value], '$.Date') AS datetime2) AS [Date], CAST(JSON_VALUE([o3].[value], '$.Enum') AS int) AS [Enum], JSON_QUERY([o3].[value], '$.Enums') AS [Enums], CAST(JSON_VALUE([o3].[value], '$.Fraction') AS decimal(18,2)) AS [Fraction], CAST(JSON_VALUE([o3].[value], '$.NullableEnum') AS int) AS [NullableEnum], JSON_QUERY([o3].[value], '$.NullableEnums') AS [NullableEnums], JSON_QUERY([o3].[value], '$.OwnedCollectionLeaf') AS [c], JSON_QUERY([o3].[value], '$.OwnedReferenceLeaf') AS [c0], [o3].[key], CAST([o3].[key] AS int) AS [c1]
FROM OPENJSON(JSON_QUERY([o2].[value], '$.OwnedCollectionBranch'), '$') AS [o3]
WHERE CAST(JSON_VALUE([o3].[value], '$.Date') AS datetime2) <> '2000-01-01T00:00:00.0000000' OR CAST(JSON_VALUE([o3].[value], '$.Date') AS datetime2) IS NULL
WHERE CAST(JSON_VALUE([o3].[value], '$.Date') AS datetime2) <> '2000-01-01T00:00:00.0000000'
) AS [o5]
) AS [s]
LEFT JOIN [JsonEntitiesBasicForCollection] AS [j0] ON [j].[Id] = [j0].[ParentId]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2860,6 +2860,42 @@ ORDER BY [e].[Id]
""");
}

public override async Task CaseOpWhen_projection(bool async)
{
await base.CaseOpWhen_projection(async);

AssertSql(
"""
SELECT CASE CASE
WHEN [e].[StringA] = N'Foo' THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
WHEN 1 THEN 3
WHEN 0 THEN 2
END
FROM [Entities1] AS [e]
ORDER BY [e].[Id]
""");
}

public override async Task CaseOpWhen_predicate(bool async)
{
await base.CaseOpWhen_predicate(async);

AssertSql(
"""
SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE CASE CASE
WHEN [e].[StringA] = N'Foo' THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
WHEN 1 THEN 3
WHEN 0 THEN 2
END = 2
""");
}

public override async Task Multiple_non_equality_comparisons_with_null_in_the_middle(bool async)
{
await base.Multiple_non_equality_comparisons_with_null_in_the_middle(async);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12632,9 +12632,9 @@ public override async Task ToString_boolean_property_nullable(bool async)

AssertSql(
"""
SELECT CASE
WHEN [l].[Eradicated] = CAST(0 AS bit) THEN N'False'
WHEN [l].[Eradicated] = CAST(1 AS bit) THEN N'True'
SELECT CASE [l].[Eradicated]
WHEN CAST(0 AS bit) THEN N'False'
WHEN CAST(1 AS bit) THEN N'True'
ELSE NULL
END
FROM [LocustHordes] AS [l]
Expand Down Expand Up @@ -12693,8 +12693,8 @@ public override async Task ToString_boolean_property_non_nullable(bool async)
AssertSql(
"""
SELECT CASE
WHEN [w].[IsAutomatic] = CAST(0 AS bit) THEN N'False'
ELSE N'True'
WHEN [w].[IsAutomatic] = CAST(1 AS bit) THEN N'True'
ELSE N'False'
END
FROM [Weapons] AS [w]
""");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10797,9 +10797,9 @@ public override async Task ToString_boolean_property_nullable(bool async)

AssertSql(
"""
SELECT CASE
WHEN [l].[Eradicated] = CAST(0 AS bit) THEN N'False'
WHEN [l].[Eradicated] = CAST(1 AS bit) THEN N'True'
SELECT CASE [l].[Eradicated]
WHEN CAST(0 AS bit) THEN N'False'
WHEN CAST(1 AS bit) THEN N'True'
ELSE NULL
END
FROM [Factions] AS [f]
Expand Down Expand Up @@ -10853,8 +10853,8 @@ public override async Task ToString_boolean_property_non_nullable(bool async)
AssertSql(
"""
SELECT CASE
WHEN [w].[IsAutomatic] = CAST(0 AS bit) THEN N'False'
ELSE N'True'
WHEN [w].[IsAutomatic] = CAST(1 AS bit) THEN N'True'
ELSE N'False'
END
FROM [Weapons] AS [w]
""");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8259,9 +8259,9 @@ public override async Task ToString_boolean_property_nullable(bool async)

AssertSql(
"""
SELECT CASE
WHEN [f].[Eradicated] = CAST(0 AS bit) THEN N'False'
WHEN [f].[Eradicated] = CAST(1 AS bit) THEN N'True'
SELECT CASE [f].[Eradicated]
WHEN CAST(0 AS bit) THEN N'False'
WHEN CAST(1 AS bit) THEN N'True'
ELSE NULL
END
FROM [Factions] FOR SYSTEM_TIME AS OF '2010-01-01T00:00:00.0000000' AS [f]
Expand Down Expand Up @@ -9192,8 +9192,8 @@ public override async Task ToString_boolean_property_non_nullable(bool async)
AssertSql(
"""
SELECT CASE
WHEN [w].[IsAutomatic] = CAST(0 AS bit) THEN N'False'
ELSE N'True'
WHEN [w].[IsAutomatic] = CAST(1 AS bit) THEN N'True'
ELSE N'False'
END
FROM [Weapons] FOR SYSTEM_TIME AS OF '2010-01-01T00:00:00.0000000' AS [w]
""");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3187,8 +3187,8 @@ public override async Task ToString_boolean_property_non_nullable(bool async)
AssertSql(
"""
SELECT CASE
WHEN NOT ("w"."IsAutomatic") THEN 'False'
ELSE 'True'
WHEN "w"."IsAutomatic" THEN 'True'
ELSE 'False'
END
FROM "Weapons" AS "w"
""");
Expand Down Expand Up @@ -5994,9 +5994,9 @@ public override async Task ToString_boolean_property_nullable(bool async)

AssertSql(
"""
SELECT CASE
WHEN "f"."Eradicated" = 0 THEN 'False'
WHEN "f"."Eradicated" THEN 'True'
SELECT CASE "f"."Eradicated"
WHEN 0 THEN 'False'
WHEN 1 THEN 'True'
ELSE NULL
END
FROM "Factions" AS "f"
Expand Down
Loading

0 comments on commit ed6213d

Please sign in to comment.