Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix optimization of CASE op WHEN #33869

Merged
merged 6 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions src/EFCore.Relational/Query/SqlNullabilityProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -515,11 +515,16 @@ protected virtual SqlExpression VisitCase(CaseExpression caseExpression, bool al
var test = Visit(
whenClause.Test, allowOptimizedExpansion: testIsCondition, preserveColumnNullabilityInformation: true, out _);

if (IsTrue(test))
var testCondition = testIsCondition
? test
: Visit(_sqlExpressionFactory.Equal(operand!, test),
allowOptimizedExpansion: testIsCondition, preserveColumnNullabilityInformation: true, out _);

if (IsTrue(testCondition))
{
testEvaluatesToTrue = true;
}
else if (IsFalse(test))
else if (IsFalse(testCondition))
{
// if test evaluates to 'false' we can remove the WhenClause
RestoreNonNullableColumnsList(currentNonNullableColumnsCount);
Expand All @@ -538,6 +543,12 @@ protected virtual SqlExpression VisitCase(CaseExpression caseExpression, bool al
// if test evaluates to 'true' we can remove every condition that comes after, including ElseResult
if (testEvaluatesToTrue)
{
// if the first When clause is always satisfied, simply return its result
if (whenClauses.Count == 1)
{
return whenClauses[0].Result;
}

break;
}
}
Expand All @@ -560,12 +571,7 @@ protected virtual SqlExpression VisitCase(CaseExpression caseExpression, bool al
return elseResult ?? _sqlExpressionFactory.Constant(null, caseExpression.Type, caseExpression.TypeMapping);
}

// if there is only one When clause and it's test evaluates to 'true' AND there is no else block, simply return the result
return elseResult == null
&& whenClauses.Count == 1
&& IsTrue(whenClauses[0].Test)
? whenClauses[0].Result
: caseExpression.Update(operand, whenClauses, elseResult);
return caseExpression.Update(operand, whenClauses, elseResult);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,14 @@ public SqlServerObjectToStringTranslator(ISqlExpressionFactory sqlExpressionFact
if (instance is ColumnExpression { IsNullable: true })
{
return _sqlExpressionFactory.Case(
instance,
new[]
{
new CaseWhenClause(
_sqlExpressionFactory.Equal(instance, _sqlExpressionFactory.Constant(false)),
_sqlExpressionFactory.Constant(false),
_sqlExpressionFactory.Constant(false.ToString())),
new CaseWhenClause(
_sqlExpressionFactory.Equal(instance, _sqlExpressionFactory.Constant(true)),
_sqlExpressionFactory.Constant(true),
_sqlExpressionFactory.Constant(true.ToString()))
Comment on lines 84 to 93
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is the same approach recommended by @roji in #33706 (comment) 🚀

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will let @maumar review/approve but this looks great... I remember this form of CASE/WHEN (with an operand) was introduced a bit later back in the day, I don't think we were aware of it originally; so I'm not surprised we have the less efficient other variant in various places - may be worth doing a pass over the code base for those.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will do the full sweep of our code, but that can be done independently of this PR

},
_sqlExpressionFactory.Constant(null, typeof(string)));
Expand All @@ -98,10 +99,10 @@ public SqlServerObjectToStringTranslator(ISqlExpressionFactory sqlExpressionFact
new[]
{
new CaseWhenClause(
_sqlExpressionFactory.Equal(instance, _sqlExpressionFactory.Constant(false)),
_sqlExpressionFactory.Constant(false.ToString()))
instance,
_sqlExpressionFactory.Constant(true.ToString()))
},
_sqlExpressionFactory.Constant(true.ToString()));
_sqlExpressionFactory.Constant(false.ToString()));
}

return TypeMapping.TryGetValue(instance.Type, out var storeType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,14 @@ public SqliteObjectToStringTranslator(ISqlExpressionFactory sqlExpressionFactory
if (instance is ColumnExpression { IsNullable: true })
{
return _sqlExpressionFactory.Case(
instance,
new[]
{
new CaseWhenClause(
_sqlExpressionFactory.Equal(instance, _sqlExpressionFactory.Constant(false)),
_sqlExpressionFactory.Constant(false),
_sqlExpressionFactory.Constant(false.ToString())),
new CaseWhenClause(
_sqlExpressionFactory.Equal(instance, _sqlExpressionFactory.Constant(true)),
_sqlExpressionFactory.Constant(true),
_sqlExpressionFactory.Constant(true.ToString()))
},
_sqlExpressionFactory.Constant(null, typeof(string)));
Expand All @@ -93,10 +94,10 @@ public SqliteObjectToStringTranslator(ISqlExpressionFactory sqlExpressionFactory
new[]
{
new CaseWhenClause(
_sqlExpressionFactory.Equal(instance, _sqlExpressionFactory.Constant(false)),
_sqlExpressionFactory.Constant(false.ToString()))
instance,
_sqlExpressionFactory.Constant(true.ToString()))
},
_sqlExpressionFactory.Constant(true.ToString()));
_sqlExpressionFactory.Constant(false.ToString()));
}

return TypeMapping.Contains(instance.Type)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,28 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
new CaseWhenClause(args[4], args[5]),
]))
);

modelBuilder.HasDbFunction(
typeof(NullSemanticsQueryFixtureBase).GetMethod(nameof(BoolSwitch)),
b => b.HasTranslation(args => new CaseExpression(
operand: args[0],
[
new CaseWhenClause(new SqlConstantExpression(true, typeMapping: BoolTypeMapping.Default), args[1]),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ranma42 while bringing the PG provider up to date, this fails because BoolTypeMapping doesn't work there - its literal representation is 1/0, whereas PG has a true boolean type with TRUE/FALSE as its literals. It's no big deal - I'm overriding the definition to use NpgsqlBoolTypeMapping.

If anything, this shows the shortcomings of our current HasTranslation API; users shouldn't need to manually deal with type mappings like this, etc.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ouch, sorry; is there a simple way to make it more portable?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can always expose some overridable method for providers to construct their boolean type mapping, but honestly it isn't worth it... They can just override the function definition (as I've done).

new CaseWhenClause(new SqlConstantExpression(false, typeMapping: BoolTypeMapping.Default), args[2]),
]))
);
}

public static int? Cases(bool c1, int v1, bool c2, int v2, bool c3, int v3) =>
c1 ? v1 :
c2 ? v2 :
c3 ? v3 :
null;

public static int BoolSwitch(bool x, int whenTrue, int whenFalse) =>
x switch
{
true => whenTrue,
false => whenFalse,
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -1796,6 +1796,31 @@ public virtual Task CaseWhen_equal_to_first_or_third_select(bool async)
assertOrder: true
);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task CaseOpWhen_projection(bool async)
=> AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>()
.OrderBy(x => x.Id)
.Select(x => NullSemanticsQueryFixtureBase.BoolSwitch(
x.StringA == "Foo", 3, 2
)),
assertOrder: true
);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task CaseOpWhen_predicate(bool async)
=> AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>()
.Where(x => NullSemanticsQueryFixtureBase.BoolSwitch(
x.StringA == "Foo", 3, 2
) == 2),
assertOrder: true
);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task False_compared_to_negated_is_null(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2860,6 +2860,42 @@ ORDER BY [e].[Id]
""");
}

public override async Task CaseOpWhen_projection(bool async)
{
await base.CaseOpWhen_projection(async);

AssertSql(
"""
SELECT CASE CASE
WHEN [e].[StringA] = N'Foo' THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
WHEN 1 THEN 3
WHEN 0 THEN 2
END
FROM [Entities1] AS [e]
ORDER BY [e].[Id]
""");
}

public override async Task CaseOpWhen_predicate(bool async)
{
await base.CaseOpWhen_predicate(async);

AssertSql(
"""
SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE CASE CASE
WHEN [e].[StringA] = N'Foo' THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
WHEN 1 THEN 3
WHEN 0 THEN 2
END = 2
""");
}

public override async Task Multiple_non_equality_comparisons_with_null_in_the_middle(bool async)
{
await base.Multiple_non_equality_comparisons_with_null_in_the_middle(async);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,36 @@ ORDER BY "e"."Id"
""");
}

public override async Task CaseOpWhen_projection(bool async)
{
await base.CaseOpWhen_projection(async);

AssertSql(
"""
SELECT CASE "e"."StringA" = 'Foo'
WHEN 1 THEN 3
WHEN 0 THEN 2
END
FROM "Entities1" AS "e"
ORDER BY "e"."Id"
""");
}

public override async Task CaseOpWhen_predicate(bool async)
{
await base.CaseOpWhen_predicate(async);

AssertSql(
"""
SELECT "e"."Id", "e"."BoolA", "e"."BoolB", "e"."BoolC", "e"."IntA", "e"."IntB", "e"."IntC", "e"."NullableBoolA", "e"."NullableBoolB", "e"."NullableBoolC", "e"."NullableIntA", "e"."NullableIntB", "e"."NullableIntC", "e"."NullableStringA", "e"."NullableStringB", "e"."NullableStringC", "e"."StringA", "e"."StringB", "e"."StringC"
FROM "Entities1" AS "e"
WHERE CASE "e"."StringA" = 'Foo'
WHEN 1 THEN 3
WHEN 0 THEN 2
END = 2
""");
}

public override async Task Bool_equal_nullable_bool_HasValue(bool async)
{
await base.Bool_equal_nullable_bool_HasValue(async);
Expand Down