Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix conditional and null-coalescing handling in parameter extraction #18640

Merged
merged 1 commit into from
Oct 29, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 59 additions & 14 deletions src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -149,26 +149,71 @@ private bool PreserveConvertNode(Expression expression)
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override Expression VisitBinary(BinaryExpression binaryExpression)
protected override Expression VisitConditional(ConditionalExpression conditionalExpression)
{
if (!binaryExpression.IsLogicalOperation())
{
return base.VisitBinary(binaryExpression);
}
var newTestExpression = TryGetConstantValue(conditionalExpression.Test) ?? Visit(conditionalExpression.Test);

var newLeftExpression = TryGetConstantValue(binaryExpression.Left) ?? Visit(binaryExpression.Left);
if (ShortCircuitBinaryExpression(newLeftExpression, binaryExpression.NodeType))
if (newTestExpression is ConstantExpression constantTestExpression
&& constantTestExpression.Value is bool constantTestValue)
{
return newLeftExpression;
return constantTestValue
? Visit(conditionalExpression.IfTrue)
: Visit(conditionalExpression.IfFalse);
}

var newRightExpression = TryGetConstantValue(binaryExpression.Right) ?? Visit(binaryExpression.Right);
if (ShortCircuitBinaryExpression(newRightExpression, binaryExpression.NodeType))
return conditionalExpression.Update(
newTestExpression,
Visit(conditionalExpression.IfTrue),
Visit(conditionalExpression.IfFalse));
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override Expression VisitBinary(BinaryExpression binaryExpression)
{
switch (binaryExpression.NodeType)
{
return newRightExpression;
}
case ExpressionType.Coalesce:
{
var newLeftExpression = TryGetConstantValue(binaryExpression.Left) ?? Visit(binaryExpression.Left);
if (newLeftExpression is ConstantExpression constantLeftExpression)
{
return constantLeftExpression.Value == null
? Visit(binaryExpression.Right)
: newLeftExpression;
}

return binaryExpression.Update(
newLeftExpression,
binaryExpression.Conversion,
Visit(binaryExpression.Right));
}

case ExpressionType.AndAlso:
case ExpressionType.OrElse:
{
var newLeftExpression = TryGetConstantValue(binaryExpression.Left) ?? Visit(binaryExpression.Left);
if (ShortCircuitLogicalExpression(newLeftExpression, binaryExpression.NodeType))
{
return newLeftExpression;
}

var newRightExpression = TryGetConstantValue(binaryExpression.Right) ?? Visit(binaryExpression.Right);
if (ShortCircuitLogicalExpression(newRightExpression, binaryExpression.NodeType))
{
return newRightExpression;
}

return binaryExpression.Update(newLeftExpression, binaryExpression.Conversion, newRightExpression);
return binaryExpression.Update(newLeftExpression, binaryExpression.Conversion, newRightExpression);
}

default:
return base.VisitBinary(binaryExpression);
}
}

private Expression TryGetConstantValue(Expression expression)
Expand All @@ -186,7 +231,7 @@ private Expression TryGetConstantValue(Expression expression)
return null;
}

private static bool ShortCircuitBinaryExpression(Expression expression, ExpressionType nodeType)
private static bool ShortCircuitLogicalExpression(Expression expression, ExpressionType nodeType)
=> expression is ConstantExpression constantExpression
&& constantExpression.Value is bool constantValue
&& ((constantValue && nodeType == ExpressionType.OrElse)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1585,7 +1585,7 @@ public override async Task Where_ternary_boolean_condition_true(bool isAsync)

SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Product"") AND (@__flag_0 ? (c[""UnitsInStock""] >= 20) : (c[""UnitsInStock""] < 20)))");
WHERE ((c[""Discriminator""] = ""Product"") AND (c[""UnitsInStock""] >= 20))");
}

public override async Task Where_ternary_boolean_condition_false(bool isAsync)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2041,7 +2041,7 @@ public override async Task OrderBy_conditional_operator_where_condition_false(bo
SELECT c
FROM root c
WHERE (c[""Discriminator""] = ""Customer"")
ORDER BY (@__p_0 ? ""ZZ"" : c[""City""])");
ORDER BY c[""City""]");
}

[ConditionalTheory(Skip = "Issue #17246")]
Expand Down
39 changes: 38 additions & 1 deletion test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,24 @@ public virtual Task Ternary_should_not_evaluate_both_sides(bool isAsync)
}));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Ternary_should_not_evaluate_both_sides_with_parameter(bool isAsync)
{
DateTime? param = null;

return AssertQuery(
isAsync,
ss => ss.Set<Order>().Select(
o => new
{
// ReSharper disable SimplifyConditionalTernaryExpression
Data1 = param != null ? o.OrderDate == param.Value : true,
Data2 = param == null ? true : o.OrderDate == param.Value
// ReSharper restore SimplifyConditionalTernaryExpression
}));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Null_Coalesce_Short_Circuit(bool isAsync)
Expand All @@ -833,10 +851,29 @@ public virtual Task Null_Coalesce_Short_Circuit(bool isAsync)

return AssertQuery(
isAsync,
ss => ss.Set<Customer>().Distinct().Select(c => new { Customer = c, Test = (test ?? values.Contains(1)) }),
ss => ss.Set<Customer>().Distinct().Select(c => new
{
Customer = c,
Test = test ?? values.Contains(1)
}),
entryCount: 91);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Null_Coalesce_Short_Circuit_with_server_correlated_leftover(bool isAsync)
{
List<Customer> values = null;
bool? test = false;

return AssertQuery(
isAsync,
ss => ss.Set<Customer>().Select(c => new
{
Result = test ?? values.Select(c2 => c2.CustomerID).Contains(c.CustomerID)
}));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Distinct_Skip_Take(bool isAsync)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -975,63 +975,22 @@ public override void Where_conditional_search_condition_in_result()
base.Where_conditional_search_condition_in_result();

AssertSql(
@"@__prm_0='True'

SELECT [e].[Id]
@"SELECT [e].[Id]
FROM [Entities1] AS [e]
WHERE CASE
WHEN @__prm_0 = CAST(1 AS bit) THEN CASE
WHEN [e].[StringA] IN (N'Foo', N'Bar') THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
ELSE CAST(0 AS bit)
END = CAST(1 AS bit)",
WHERE [e].[StringA] IN (N'Foo', N'Bar')",
//
@"@__p_0='False'

SELECT [e].[Id]
@"SELECT [e].[Id]
FROM [Entities1] AS [e]
WHERE CASE
WHEN @__p_0 = CAST(1 AS bit) THEN CAST(1 AS bit)
ELSE CASE
WHEN [e].[StringA] LIKE N'A%' THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
END = CAST(1 AS bit)");
WHERE [e].[StringA] LIKE N'A%'");
}

public override void Where_nested_conditional_search_condition_in_result()
{
base.Where_nested_conditional_search_condition_in_result();

AssertSql(
@"@__prm1_0='True'
@__prm2_1='False'

SELECT [e].[Id]
FROM [Entities1] AS [e]
WHERE CASE
WHEN @__prm1_0 = CAST(1 AS bit) THEN CASE
WHEN @__prm2_1 = CAST(1 AS bit) THEN CASE
WHEN [e].[BoolA] = CAST(1 AS bit) THEN CASE
WHEN [e].[StringA] LIKE N'A%' THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
ELSE CAST(0 AS bit)
END
ELSE CAST(1 AS bit)
END
ELSE CASE
WHEN [e].[BoolB] = CAST(1 AS bit) THEN CASE
WHEN [e].[StringA] IN (N'Foo', N'Bar') THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
ELSE CASE
WHEN [e].[StringB] IN (N'Foo', N'Bar') THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
END
END = CAST(1 AS bit)");
@"SELECT [e].[Id]
FROM [Entities1] AS [e]");
}

public override void Where_equal_using_relational_null_semantics()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1359,20 +1359,9 @@ public override async Task Where_ternary_boolean_condition_true(bool isAsync)
await base.Where_ternary_boolean_condition_true(isAsync);

AssertSql(
@"@__flag_0='True'

SELECT [p].[ProductID], [p].[Discontinued], [p].[ProductName], [p].[SupplierID], [p].[UnitPrice], [p].[UnitsInStock]
@"SELECT [p].[ProductID], [p].[Discontinued], [p].[ProductName], [p].[SupplierID], [p].[UnitPrice], [p].[UnitsInStock]
FROM [Products] AS [p]
WHERE CASE
WHEN @__flag_0 = CAST(1 AS bit) THEN CASE
WHEN [p].[UnitsInStock] >= CAST(20 AS smallint) THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
ELSE CASE
WHEN [p].[UnitsInStock] < CAST(20 AS smallint) THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
END = CAST(1 AS bit)");
WHERE [p].[UnitsInStock] >= CAST(20 AS smallint)");
}

public override async Task Where_ternary_boolean_condition_false(bool isAsync)
Expand Down Expand Up @@ -1450,6 +1439,15 @@ public override async Task Ternary_should_not_evaluate_both_sides(bool isAsync)
FROM [Customers] AS [c]");
}

public override async Task Ternary_should_not_evaluate_both_sides_with_parameter(bool isAsync)
{
await base.Ternary_should_not_evaluate_both_sides_with_parameter(isAsync);

AssertSql(
@"SELECT CAST(1 AS bit) AS [Data1]
FROM [Orders] AS [o]");
}

public override async Task Where_compare_constructed_multi_value_equal(bool isAsync)
{
await base.Where_compare_constructed_multi_value_equal(isAsync);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2588,28 +2588,33 @@ public override async Task Null_Coalesce_Short_Circuit(bool isAsync)
{
await base.Null_Coalesce_Short_Circuit(isAsync);

// issue #15994
// AssertSql(
// @"SELECT [t].[CustomerID], [t].[Address], [t].[City], [t].[CompanyName], [t].[ContactName], [t].[ContactTitle], [t].[Country], [t].[Fax], [t].[Phone], [t].[PostalCode], [t].[Region]
//FROM (
// SELECT DISTINCT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
// FROM [Customers] AS [c]
//) AS [t]");
AssertSql(
@"@__p_0='False'

SELECT [t].[CustomerID], [t].[Address], [t].[City], [t].[CompanyName], [t].[ContactName], [t].[ContactTitle], [t].[Country], [t].[Fax], [t].[Phone], [t].[PostalCode], [t].[Region], @__p_0 AS [Test]
FROM (
SELECT DISTINCT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
) AS [t]");
}

public override async Task Null_Coalesce_Short_Circuit_with_server_correlated_leftover(bool isAsync)
{
await base.Null_Coalesce_Short_Circuit_with_server_correlated_leftover(isAsync);

AssertSql(
@"SELECT CAST(0 AS bit) AS [Result]
FROM [Customers] AS [c]");
}

public override async Task OrderBy_conditional_operator_where_condition_false(bool isAsync)
{
await base.OrderBy_conditional_operator_where_condition_false(isAsync);

AssertSql(
@"@__p_0='False'

SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
@"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
ORDER BY CASE
WHEN @__p_0 = CAST(1 AS bit) THEN N'ZZ'
ELSE [c].[City]
END");
ORDER BY [c].[City]");
}

public override async Task OrderBy_comparison_operator(bool isAsync)
Expand Down