Skip to content

Commit

Permalink
Fix conditional and null-coalescing handling in parameter extraction
Browse files Browse the repository at this point in the history
Fixes #17942
Fixes #13859
  • Loading branch information
roji committed Oct 29, 2019
1 parent 45adf14 commit 0cb231a
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 91 deletions.
73 changes: 59 additions & 14 deletions src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -149,26 +149,71 @@ private bool PreserveConvertNode(Expression expression)
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override Expression VisitBinary(BinaryExpression binaryExpression)
protected override Expression VisitConditional(ConditionalExpression conditionalExpression)
{
if (!binaryExpression.IsLogicalOperation())
{
return base.VisitBinary(binaryExpression);
}
var newTestExpression = TryGetConstantValue(conditionalExpression.Test) ?? Visit(conditionalExpression.Test);

var newLeftExpression = TryGetConstantValue(binaryExpression.Left) ?? Visit(binaryExpression.Left);
if (ShortCircuitBinaryExpression(newLeftExpression, binaryExpression.NodeType))
if (newTestExpression is ConstantExpression constantTestExpression
&& constantTestExpression.Value is bool constantTestValue)
{
return newLeftExpression;
return constantTestValue
? Visit(conditionalExpression.IfTrue)
: Visit(conditionalExpression.IfFalse);
}

var newRightExpression = TryGetConstantValue(binaryExpression.Right) ?? Visit(binaryExpression.Right);
if (ShortCircuitBinaryExpression(newRightExpression, binaryExpression.NodeType))
return conditionalExpression.Update(
newTestExpression,
Visit(conditionalExpression.IfTrue),
Visit(conditionalExpression.IfFalse));
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override Expression VisitBinary(BinaryExpression binaryExpression)
{
switch (binaryExpression.NodeType)
{
return newRightExpression;
}
case ExpressionType.Coalesce:
{
var newLeftExpression = TryGetConstantValue(binaryExpression.Left) ?? Visit(binaryExpression.Left);
if (newLeftExpression is ConstantExpression constantLeftExpression)
{
return constantLeftExpression.Value == null
? Visit(binaryExpression.Right)
: newLeftExpression;
}

return binaryExpression.Update(
newLeftExpression,
binaryExpression.Conversion,
Visit(binaryExpression.Right));
}

case ExpressionType.AndAlso:
case ExpressionType.OrElse:
{
var newLeftExpression = TryGetConstantValue(binaryExpression.Left) ?? Visit(binaryExpression.Left);
if (ShortCircuitLogicalExpression(newLeftExpression, binaryExpression.NodeType))
{
return newLeftExpression;
}

var newRightExpression = TryGetConstantValue(binaryExpression.Right) ?? Visit(binaryExpression.Right);
if (ShortCircuitLogicalExpression(newRightExpression, binaryExpression.NodeType))
{
return newRightExpression;
}

return binaryExpression.Update(newLeftExpression, binaryExpression.Conversion, newRightExpression);
return binaryExpression.Update(newLeftExpression, binaryExpression.Conversion, newRightExpression);
}

default:
return base.VisitBinary(binaryExpression);
}
}

private Expression TryGetConstantValue(Expression expression)
Expand All @@ -186,7 +231,7 @@ private Expression TryGetConstantValue(Expression expression)
return null;
}

private static bool ShortCircuitBinaryExpression(Expression expression, ExpressionType nodeType)
private static bool ShortCircuitLogicalExpression(Expression expression, ExpressionType nodeType)
=> expression is ConstantExpression constantExpression
&& constantExpression.Value is bool constantValue
&& ((constantValue && nodeType == ExpressionType.OrElse)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1585,7 +1585,7 @@ public override async Task Where_ternary_boolean_condition_true(bool isAsync)
SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Product"") AND (@__flag_0 ? (c[""UnitsInStock""] >= 20) : (c[""UnitsInStock""] < 20)))");
WHERE ((c[""Discriminator""] = ""Product"") AND (c[""UnitsInStock""] >= 20))");
}

public override async Task Where_ternary_boolean_condition_false(bool isAsync)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2041,7 +2041,7 @@ public override async Task OrderBy_conditional_operator_where_condition_false(bo
SELECT c
FROM root c
WHERE (c[""Discriminator""] = ""Customer"")
ORDER BY (@__p_0 ? ""ZZ"" : c[""City""])");
ORDER BY c[""City""]");
}

[ConditionalTheory(Skip = "Issue #17246")]
Expand Down
39 changes: 38 additions & 1 deletion test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,24 @@ public virtual Task Ternary_should_not_evaluate_both_sides(bool isAsync)
}));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Ternary_should_not_evaluate_both_sides_with_parameter(bool isAsync)
{
DateTime? param = null;

return AssertQuery(
isAsync,
ss => ss.Set<Order>().Select(
o => new
{
// ReSharper disable SimplifyConditionalTernaryExpression
Data1 = param != null ? o.OrderDate == param.Value : true,
Data2 = param == null ? true : o.OrderDate == param.Value
// ReSharper restore SimplifyConditionalTernaryExpression
}));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Null_Coalesce_Short_Circuit(bool isAsync)
Expand All @@ -833,10 +851,29 @@ public virtual Task Null_Coalesce_Short_Circuit(bool isAsync)

return AssertQuery(
isAsync,
ss => ss.Set<Customer>().Distinct().Select(c => new { Customer = c, Test = (test ?? values.Contains(1)) }),
ss => ss.Set<Customer>().Distinct().Select(c => new
{
Customer = c,
Test = test ?? values.Contains(1)
}),
entryCount: 91);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Null_Coalesce_Short_Circuit_with_server_correlated_leftover(bool isAsync)
{
List<Customer> values = null;
bool? test = false;

return AssertQuery(
isAsync,
ss => ss.Set<Customer>().Select(c => new
{
Result = test ?? values.Select(c2 => c2.CustomerID).Contains(c.CustomerID)
}));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Distinct_Skip_Take(bool isAsync)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -975,63 +975,22 @@ public override void Where_conditional_search_condition_in_result()
base.Where_conditional_search_condition_in_result();

AssertSql(
@"@__prm_0='True'
SELECT [e].[Id]
@"SELECT [e].[Id]
FROM [Entities1] AS [e]
WHERE CASE
WHEN @__prm_0 = CAST(1 AS bit) THEN CASE
WHEN [e].[StringA] IN (N'Foo', N'Bar') THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
ELSE CAST(0 AS bit)
END = CAST(1 AS bit)",
WHERE [e].[StringA] IN (N'Foo', N'Bar')",
//
@"@__p_0='False'
SELECT [e].[Id]
@"SELECT [e].[Id]
FROM [Entities1] AS [e]
WHERE CASE
WHEN @__p_0 = CAST(1 AS bit) THEN CAST(1 AS bit)
ELSE CASE
WHEN [e].[StringA] LIKE N'A%' THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
END = CAST(1 AS bit)");
WHERE [e].[StringA] LIKE N'A%'");
}

public override void Where_nested_conditional_search_condition_in_result()
{
base.Where_nested_conditional_search_condition_in_result();

AssertSql(
@"@__prm1_0='True'
@__prm2_1='False'
SELECT [e].[Id]
FROM [Entities1] AS [e]
WHERE CASE
WHEN @__prm1_0 = CAST(1 AS bit) THEN CASE
WHEN @__prm2_1 = CAST(1 AS bit) THEN CASE
WHEN [e].[BoolA] = CAST(1 AS bit) THEN CASE
WHEN [e].[StringA] LIKE N'A%' THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
ELSE CAST(0 AS bit)
END
ELSE CAST(1 AS bit)
END
ELSE CASE
WHEN [e].[BoolB] = CAST(1 AS bit) THEN CASE
WHEN [e].[StringA] IN (N'Foo', N'Bar') THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
ELSE CASE
WHEN [e].[StringB] IN (N'Foo', N'Bar') THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
END
END = CAST(1 AS bit)");
@"SELECT [e].[Id]
FROM [Entities1] AS [e]");
}

public override void Where_equal_using_relational_null_semantics()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1359,20 +1359,9 @@ public override async Task Where_ternary_boolean_condition_true(bool isAsync)
await base.Where_ternary_boolean_condition_true(isAsync);

AssertSql(
@"@__flag_0='True'
SELECT [p].[ProductID], [p].[Discontinued], [p].[ProductName], [p].[SupplierID], [p].[UnitPrice], [p].[UnitsInStock]
@"SELECT [p].[ProductID], [p].[Discontinued], [p].[ProductName], [p].[SupplierID], [p].[UnitPrice], [p].[UnitsInStock]
FROM [Products] AS [p]
WHERE CASE
WHEN @__flag_0 = CAST(1 AS bit) THEN CASE
WHEN [p].[UnitsInStock] >= CAST(20 AS smallint) THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
ELSE CASE
WHEN [p].[UnitsInStock] < CAST(20 AS smallint) THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
END = CAST(1 AS bit)");
WHERE [p].[UnitsInStock] >= CAST(20 AS smallint)");
}

public override async Task Where_ternary_boolean_condition_false(bool isAsync)
Expand Down Expand Up @@ -1450,6 +1439,15 @@ public override async Task Ternary_should_not_evaluate_both_sides(bool isAsync)
FROM [Customers] AS [c]");
}

public override async Task Ternary_should_not_evaluate_both_sides_with_parameter(bool isAsync)
{
await base.Ternary_should_not_evaluate_both_sides_with_parameter(isAsync);

AssertSql(
@"SELECT CAST(1 AS bit) AS [Data1]
FROM [Orders] AS [o]");
}

public override async Task Where_compare_constructed_multi_value_equal(bool isAsync)
{
await base.Where_compare_constructed_multi_value_equal(isAsync);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2588,28 +2588,33 @@ public override async Task Null_Coalesce_Short_Circuit(bool isAsync)
{
await base.Null_Coalesce_Short_Circuit(isAsync);

// issue #15994
// AssertSql(
// @"SELECT [t].[CustomerID], [t].[Address], [t].[City], [t].[CompanyName], [t].[ContactName], [t].[ContactTitle], [t].[Country], [t].[Fax], [t].[Phone], [t].[PostalCode], [t].[Region]
//FROM (
// SELECT DISTINCT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
// FROM [Customers] AS [c]
//) AS [t]");
AssertSql(
@"@__p_0='False'
SELECT [t].[CustomerID], [t].[Address], [t].[City], [t].[CompanyName], [t].[ContactName], [t].[ContactTitle], [t].[Country], [t].[Fax], [t].[Phone], [t].[PostalCode], [t].[Region], @__p_0 AS [Test]
FROM (
SELECT DISTINCT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
) AS [t]");
}

public override async Task Null_Coalesce_Short_Circuit_with_server_correlated_leftover(bool isAsync)
{
await base.Null_Coalesce_Short_Circuit_with_server_correlated_leftover(isAsync);

AssertSql(
@"SELECT CAST(0 AS bit) AS [Result]
FROM [Customers] AS [c]");
}

public override async Task OrderBy_conditional_operator_where_condition_false(bool isAsync)
{
await base.OrderBy_conditional_operator_where_condition_false(isAsync);

AssertSql(
@"@__p_0='False'
SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
@"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
ORDER BY CASE
WHEN @__p_0 = CAST(1 AS bit) THEN N'ZZ'
ELSE [c].[City]
END");
ORDER BY [c].[City]");
}

public override async Task OrderBy_comparison_operator(bool isAsync)
Expand Down

0 comments on commit 0cb231a

Please sign in to comment.