diff --git a/src/EFCore/Query/Internal/InvocationExpressionRemovingExpressionVisitor.cs b/src/EFCore/Query/Internal/InvocationExpressionRemovingExpressionVisitor.cs new file mode 100644 index 00000000000..960cb76b114 --- /dev/null +++ b/src/EFCore/Query/Internal/InvocationExpressionRemovingExpressionVisitor.cs @@ -0,0 +1,45 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Linq.Expressions; + +namespace Microsoft.EntityFrameworkCore.Query.Internal +{ + public class InvocationExpressionRemovingExpressionVisitor : ExpressionVisitor + { + protected override Expression VisitInvocation(InvocationExpression invocationExpression) + { + var invokedExpression = StripTrivialConversions(invocationExpression.Expression); + + return invokedExpression is LambdaExpression lambdaExpression + ? InlineLambdaExpression(lambdaExpression, invocationExpression.Arguments) + : base.VisitInvocation(invocationExpression); + } + + private Expression StripTrivialConversions(Expression expression) + { + while (expression is UnaryExpression unaryExpression + && unaryExpression.NodeType == ExpressionType.Convert + && expression.Type == unaryExpression.Operand.Type + && unaryExpression.Method == null) + { + expression = unaryExpression.Operand; + } + + return expression; + } + + private Expression InlineLambdaExpression(LambdaExpression lambdaExpression, ReadOnlyCollection arguments) + { + var replacements = new Dictionary(arguments.Count); + for (var i = 0; i < lambdaExpression.Parameters.Count; i++) + { + replacements.Add(lambdaExpression.Parameters[i], arguments[i]); + } + + return new ReplacingExpressionVisitor(replacements).Visit(lambdaExpression.Body); + } + } +} diff --git a/src/EFCore/Query/QueryTranslationPreprocessor.cs b/src/EFCore/Query/QueryTranslationPreprocessor.cs index be10f8531a9..fae6e25eaf9 100644 --- a/src/EFCore/Query/QueryTranslationPreprocessor.cs +++ b/src/EFCore/Query/QueryTranslationPreprocessor.cs @@ -26,6 +26,7 @@ public virtual Expression Process(Expression query) { query = new EnumerableToQueryableMethodConvertingExpressionVisitor().Visit(query); query = new QueryMetadataExtractingExpressionVisitor(_queryCompilationContext).Visit(query); + query = new InvocationExpressionRemovingExpressionVisitor().Visit(query); query = new AllAnyToContainsRewritingExpressionVisitor().Visit(query); query = new GroupJoinFlatteningExpressionVisitor().Visit(query); query = new NullCheckRemovingExpressionVisitor().Visit(query); diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.Where.cs b/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.Where.cs index 7c08c7e30cb..e5effc18e73 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.Where.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.Where.cs @@ -1515,9 +1515,20 @@ FROM root c WHERE ((c[""Discriminator""] = ""Customer"") AND (c[""Fax""] = null))"); } - public override async Task Where_expression_invoke(bool isAsync) + public override async Task Where_expression_invoke_1(bool isAsync) { - await base.Where_expression_invoke(isAsync); + await base.Where_expression_invoke_1(isAsync); + + AssertSql( + @"SELECT c +FROM root c +WHERE ((c[""Discriminator""] = ""Customer"") AND (c[""CustomerID""] = ""ALFKI""))"); + } + + [ConditionalTheory(Skip = "Issue #17246")] + public override async Task Where_expression_invoke_2(bool isAsync) + { + await base.Where_expression_invoke_2(isAsync); AssertSql( @"SELECT c diff --git a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.Where.cs b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.Where.cs index 2f64f44a553..b601844dfc4 100644 --- a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.Where.cs +++ b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.Where.cs @@ -1495,9 +1495,9 @@ public virtual Task Where_default(bool isAsync) entryCount: 22); } - [ConditionalTheory(Skip = "Issue#14572")] + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] - public virtual Task Where_expression_invoke(bool isAsync) + public virtual Task Where_expression_invoke_1(bool isAsync) { Expression> expression = c => c.CustomerID == "ALFKI"; var parameter = Expression.Parameter(typeof(Customer), "c"); @@ -1509,6 +1509,22 @@ public virtual Task Where_expression_invoke(bool isAsync) entryCount: 1); } + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Where_expression_invoke_2(bool isAsync) + { + Expression> customer = o => o.Customer; + Expression> predicate = c => c.CustomerID == "ALFKI"; + var exp = Expression.Lambda>( + Expression.Invoke(predicate, customer.Body), + customer.Parameters); + + return AssertQuery( + isAsync, + ss => ss.Set().Where(exp), + entryCount: 6); + } + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Where_concat_string_int_comparison1(bool isAsync) diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.Where.cs b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.Where.cs index d4573596236..e5f860b8d82 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.Where.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.Where.cs @@ -1270,9 +1270,9 @@ FROM [Customers] AS [c] WHERE [c].[Fax] IS NULL"); } - public override async Task Where_expression_invoke(bool isAsync) + public override async Task Where_expression_invoke_1(bool isAsync) { - await base.Where_expression_invoke(isAsync); + await base.Where_expression_invoke_1(isAsync); AssertSql( @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] @@ -1280,6 +1280,17 @@ FROM [Customers] AS [c] WHERE [c].[CustomerID] = N'ALFKI'"); } + public override async Task Where_expression_invoke_2(bool isAsync) + { + await base.Where_expression_invoke_2(isAsync); + + AssertSql( + @"SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate] +FROM [Orders] AS [o] +LEFT JOIN [Customers] AS [c] ON [o].[CustomerID] = [c].[CustomerID] +WHERE ([c].[CustomerID] = N'ALFKI') AND [c].[CustomerID] IS NOT NULL"); + } + public override async Task Where_concat_string_int_comparison1(bool isAsync) { await base.Where_concat_string_int_comparison1(isAsync);