diff --git a/src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs b/src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs index 636f5ed13dc..ece9df86d97 100644 --- a/src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs @@ -149,26 +149,71 @@ private bool PreserveConvertNode(Expression expression) /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - protected override Expression VisitBinary(BinaryExpression binaryExpression) + protected override Expression VisitConditional(ConditionalExpression conditionalExpression) { - if (!binaryExpression.IsLogicalOperation()) - { - return base.VisitBinary(binaryExpression); - } + var newTestExpression = TryGetConstantValue(conditionalExpression.Test) ?? Visit(conditionalExpression.Test); - var newLeftExpression = TryGetConstantValue(binaryExpression.Left) ?? Visit(binaryExpression.Left); - if (ShortCircuitBinaryExpression(newLeftExpression, binaryExpression.NodeType)) + if (newTestExpression is ConstantExpression constantTestExpression + && constantTestExpression.Value is bool constantTestValue) { - return newLeftExpression; + return constantTestValue + ? Visit(conditionalExpression.IfTrue) + : Visit(conditionalExpression.IfFalse); } - var newRightExpression = TryGetConstantValue(binaryExpression.Right) ?? Visit(binaryExpression.Right); - if (ShortCircuitBinaryExpression(newRightExpression, binaryExpression.NodeType)) + return conditionalExpression.Update( + newTestExpression, + Visit(conditionalExpression.IfTrue), + Visit(conditionalExpression.IfFalse)); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitBinary(BinaryExpression binaryExpression) + { + switch (binaryExpression.NodeType) { - return newRightExpression; - } + case ExpressionType.Coalesce: + { + var newLeftExpression = TryGetConstantValue(binaryExpression.Left) ?? Visit(binaryExpression.Left); + if (newLeftExpression is ConstantExpression constantLeftExpression) + { + return constantLeftExpression.Value == null + ? Visit(binaryExpression.Right) + : newLeftExpression; + } + + return binaryExpression.Update( + newLeftExpression, + binaryExpression.Conversion, + Visit(binaryExpression.Right)); + } + + case ExpressionType.AndAlso: + case ExpressionType.OrElse: + { + var newLeftExpression = TryGetConstantValue(binaryExpression.Left) ?? Visit(binaryExpression.Left); + if (ShortCircuitLogicalExpression(newLeftExpression, binaryExpression.NodeType)) + { + return newLeftExpression; + } + + var newRightExpression = TryGetConstantValue(binaryExpression.Right) ?? Visit(binaryExpression.Right); + if (ShortCircuitLogicalExpression(newRightExpression, binaryExpression.NodeType)) + { + return newRightExpression; + } - return binaryExpression.Update(newLeftExpression, binaryExpression.Conversion, newRightExpression); + return binaryExpression.Update(newLeftExpression, binaryExpression.Conversion, newRightExpression); + } + + default: + return base.VisitBinary(binaryExpression); + } } private Expression TryGetConstantValue(Expression expression) @@ -186,7 +231,7 @@ private Expression TryGetConstantValue(Expression expression) return null; } - private static bool ShortCircuitBinaryExpression(Expression expression, ExpressionType nodeType) + private static bool ShortCircuitLogicalExpression(Expression expression, ExpressionType nodeType) => expression is ConstantExpression constantExpression && constantExpression.Value is bool constantValue && ((constantValue && nodeType == ExpressionType.OrElse) diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.Where.cs b/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.Where.cs index e5effc18e73..c758419fde4 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.Where.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.Where.cs @@ -1585,7 +1585,7 @@ public override async Task Where_ternary_boolean_condition_true(bool isAsync) SELECT c FROM root c -WHERE ((c[""Discriminator""] = ""Product"") AND (@__flag_0 ? (c[""UnitsInStock""] >= 20) : (c[""UnitsInStock""] < 20)))"); +WHERE ((c[""Discriminator""] = ""Product"") AND (c[""UnitsInStock""] >= 20))"); } public override async Task Where_ternary_boolean_condition_false(bool isAsync) diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.cs index 0eba1ebfd2c..11ffce44cf3 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.cs @@ -2041,7 +2041,7 @@ public override async Task OrderBy_conditional_operator_where_condition_false(bo SELECT c FROM root c WHERE (c[""Discriminator""] = ""Customer"") -ORDER BY (@__p_0 ? ""ZZ"" : c[""City""])"); +ORDER BY c[""City""]"); } [ConditionalTheory(Skip = "Issue #17246")] diff --git a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs index 94c571c43b7..97a72be4aed 100644 --- a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs @@ -824,6 +824,24 @@ public virtual Task Ternary_should_not_evaluate_both_sides(bool isAsync) })); } + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Ternary_should_not_evaluate_both_sides_with_parameter(bool isAsync) + { + DateTime? param = null; + + return AssertQuery( + isAsync, + ss => ss.Set().Select( + o => new + { + // ReSharper disable SimplifyConditionalTernaryExpression + Data1 = param != null ? o.OrderDate == param.Value : true, + Data2 = param == null ? true : o.OrderDate == param.Value + // ReSharper restore SimplifyConditionalTernaryExpression + })); + } + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Null_Coalesce_Short_Circuit(bool isAsync) @@ -833,10 +851,29 @@ public virtual Task Null_Coalesce_Short_Circuit(bool isAsync) return AssertQuery( isAsync, - ss => ss.Set().Distinct().Select(c => new { Customer = c, Test = (test ?? values.Contains(1)) }), + ss => ss.Set().Distinct().Select(c => new + { + Customer = c, + Test = test ?? values.Contains(1) + }), entryCount: 91); } + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Null_Coalesce_Short_Circuit_with_server_correlated_leftover(bool isAsync) + { + List values = null; + bool? test = false; + + return AssertQuery( + isAsync, + ss => ss.Set().Select(c => new + { + Result = test ?? values.Select(c2 => c2.CustomerID).Contains(c.CustomerID) + })); + } + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Distinct_Skip_Take(bool isAsync) diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NullSemanticsQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NullSemanticsQuerySqlServerTest.cs index 3b84689bbbc..fd40ea0bc79 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NullSemanticsQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NullSemanticsQuerySqlServerTest.cs @@ -975,29 +975,13 @@ public override void Where_conditional_search_condition_in_result() base.Where_conditional_search_condition_in_result(); AssertSql( - @"@__prm_0='True' - -SELECT [e].[Id] + @"SELECT [e].[Id] FROM [Entities1] AS [e] -WHERE CASE - WHEN @__prm_0 = CAST(1 AS bit) THEN CASE - WHEN [e].[StringA] IN (N'Foo', N'Bar') THEN CAST(1 AS bit) - ELSE CAST(0 AS bit) - END - ELSE CAST(0 AS bit) -END = CAST(1 AS bit)", +WHERE [e].[StringA] IN (N'Foo', N'Bar')", // - @"@__p_0='False' - -SELECT [e].[Id] + @"SELECT [e].[Id] FROM [Entities1] AS [e] -WHERE CASE - WHEN @__p_0 = CAST(1 AS bit) THEN CAST(1 AS bit) - ELSE CASE - WHEN [e].[StringA] LIKE N'A%' THEN CAST(1 AS bit) - ELSE CAST(0 AS bit) - END -END = CAST(1 AS bit)"); +WHERE [e].[StringA] LIKE N'A%'"); } public override void Where_nested_conditional_search_condition_in_result() @@ -1005,33 +989,8 @@ public override void Where_nested_conditional_search_condition_in_result() base.Where_nested_conditional_search_condition_in_result(); AssertSql( - @"@__prm1_0='True' -@__prm2_1='False' - -SELECT [e].[Id] -FROM [Entities1] AS [e] -WHERE CASE - WHEN @__prm1_0 = CAST(1 AS bit) THEN CASE - WHEN @__prm2_1 = CAST(1 AS bit) THEN CASE - WHEN [e].[BoolA] = CAST(1 AS bit) THEN CASE - WHEN [e].[StringA] LIKE N'A%' THEN CAST(1 AS bit) - ELSE CAST(0 AS bit) - END - ELSE CAST(0 AS bit) - END - ELSE CAST(1 AS bit) - END - ELSE CASE - WHEN [e].[BoolB] = CAST(1 AS bit) THEN CASE - WHEN [e].[StringA] IN (N'Foo', N'Bar') THEN CAST(1 AS bit) - ELSE CAST(0 AS bit) - END - ELSE CASE - WHEN [e].[StringB] IN (N'Foo', N'Bar') THEN CAST(1 AS bit) - ELSE CAST(0 AS bit) - END - END -END = CAST(1 AS bit)"); + @"SELECT [e].[Id] +FROM [Entities1] AS [e]"); } public override void Where_equal_using_relational_null_semantics() diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.Where.cs b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.Where.cs index 7c092752f8d..aaa85ff9cda 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.Where.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.Where.cs @@ -1359,20 +1359,9 @@ public override async Task Where_ternary_boolean_condition_true(bool isAsync) await base.Where_ternary_boolean_condition_true(isAsync); AssertSql( - @"@__flag_0='True' - -SELECT [p].[ProductID], [p].[Discontinued], [p].[ProductName], [p].[SupplierID], [p].[UnitPrice], [p].[UnitsInStock] + @"SELECT [p].[ProductID], [p].[Discontinued], [p].[ProductName], [p].[SupplierID], [p].[UnitPrice], [p].[UnitsInStock] FROM [Products] AS [p] -WHERE CASE - WHEN @__flag_0 = CAST(1 AS bit) THEN CASE - WHEN [p].[UnitsInStock] >= CAST(20 AS smallint) THEN CAST(1 AS bit) - ELSE CAST(0 AS bit) - END - ELSE CASE - WHEN [p].[UnitsInStock] < CAST(20 AS smallint) THEN CAST(1 AS bit) - ELSE CAST(0 AS bit) - END -END = CAST(1 AS bit)"); +WHERE [p].[UnitsInStock] >= CAST(20 AS smallint)"); } public override async Task Where_ternary_boolean_condition_false(bool isAsync) @@ -1450,6 +1439,15 @@ public override async Task Ternary_should_not_evaluate_both_sides(bool isAsync) FROM [Customers] AS [c]"); } + public override async Task Ternary_should_not_evaluate_both_sides_with_parameter(bool isAsync) + { + await base.Ternary_should_not_evaluate_both_sides_with_parameter(isAsync); + + AssertSql( + @"SELECT CAST(1 AS bit) AS [Data1] +FROM [Orders] AS [o]"); + } + public override async Task Where_compare_constructed_multi_value_equal(bool isAsync) { await base.Where_compare_constructed_multi_value_equal(isAsync); diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs index 84c8615ee35..8a080ace7ca 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs @@ -2588,13 +2588,23 @@ public override async Task Null_Coalesce_Short_Circuit(bool isAsync) { await base.Null_Coalesce_Short_Circuit(isAsync); - // issue #15994 -// AssertSql( -// @"SELECT [t].[CustomerID], [t].[Address], [t].[City], [t].[CompanyName], [t].[ContactName], [t].[ContactTitle], [t].[Country], [t].[Fax], [t].[Phone], [t].[PostalCode], [t].[Region] -//FROM ( -// SELECT DISTINCT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] -// FROM [Customers] AS [c] -//) AS [t]"); + AssertSql( + @"@__p_0='False' + +SELECT [t].[CustomerID], [t].[Address], [t].[City], [t].[CompanyName], [t].[ContactName], [t].[ContactTitle], [t].[Country], [t].[Fax], [t].[Phone], [t].[PostalCode], [t].[Region], @__p_0 AS [Test] +FROM ( + SELECT DISTINCT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] + FROM [Customers] AS [c] +) AS [t]"); + } + + public override async Task Null_Coalesce_Short_Circuit_with_server_correlated_leftover(bool isAsync) + { + await base.Null_Coalesce_Short_Circuit_with_server_correlated_leftover(isAsync); + + AssertSql( + @"SELECT CAST(0 AS bit) AS [Result] +FROM [Customers] AS [c]"); } public override async Task OrderBy_conditional_operator_where_condition_false(bool isAsync) @@ -2602,14 +2612,9 @@ public override async Task OrderBy_conditional_operator_where_condition_false(bo await base.OrderBy_conditional_operator_where_condition_false(isAsync); AssertSql( - @"@__p_0='False' - -SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] + @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] FROM [Customers] AS [c] -ORDER BY CASE - WHEN @__p_0 = CAST(1 AS bit) THEN N'ZZ' - ELSE [c].[City] -END"); +ORDER BY [c].[City]"); } public override async Task OrderBy_comparison_operator(bool isAsync)