diff --git a/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs b/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs index 917e59753e9..1378e2d1bf9 100644 --- a/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs @@ -368,7 +368,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp return methodCallExpression.Update(Unwrap(Visit(methodCallExpression.Object)), newArguments); } - protected virtual Expression VisitContainsMethodCall(MethodCallExpression methodCallExpression) + private Expression VisitContainsMethodCall(MethodCallExpression methodCallExpression) { // We handle both Contains the extension method and the instance method var (newSource, newItem) = methodCallExpression.Arguments.Count == 2 @@ -465,7 +465,7 @@ Expression NoTranslation() => methodCallExpression.Arguments.Count == 2 : methodCallExpression.Update(Unwrap(newSource), new[] { Unwrap(newItem) }); } - protected virtual Expression VisitOrderingMethodCall(MethodCallExpression methodCallExpression) + private Expression VisitOrderingMethodCall(MethodCallExpression methodCallExpression) { var arguments = methodCallExpression.Arguments; var newSource = Visit(arguments[0]); @@ -540,7 +540,7 @@ static MethodInfo GetOrderingMethodInfo(bool firstOrdering, bool ascending) } } - protected virtual Expression VisitSelectMethodCall(MethodCallExpression methodCallExpression) + private Expression VisitSelectMethodCall(MethodCallExpression methodCallExpression) { var arguments = methodCallExpression.Arguments; var newSource = Visit(arguments[0]); @@ -587,7 +587,7 @@ protected virtual Expression VisitSelectMethodCall(MethodCallExpression methodCa throw new InvalidOperationException(CoreStrings.QueryFailed(methodCallExpression.Print(), GetType().Name)); } - protected virtual Expression VisitJoinMethodCall(MethodCallExpression methodCallExpression) + private Expression VisitJoinMethodCall(MethodCallExpression methodCallExpression) { var arguments = methodCallExpression.Arguments; @@ -694,7 +694,7 @@ protected virtual Expression VisitJoinMethodCall(MethodCallExpression methodCall : (Expression)newMethodCall; } - protected virtual Expression VisitOfType(MethodCallExpression methodCallExpression) + private Expression VisitOfType(MethodCallExpression methodCallExpression) { var newSource = Visit(methodCallExpression.Arguments[0]); var updatedMethodCall = methodCallExpression.Update(null, new[] { Unwrap(newSource) }); @@ -719,7 +719,7 @@ protected virtual Expression VisitOfType(MethodCallExpression methodCallExpressi /// Replaces the lambda's single parameter with a type wrapper based on the given source, and then visits /// the lambda's body. /// - protected virtual LambdaExpression RewriteAndVisitLambda(LambdaExpression lambda, EntityReferenceExpression source) + private LambdaExpression RewriteAndVisitLambda(LambdaExpression lambda, EntityReferenceExpression source) => Expression.Lambda( lambda.Type, Visit( @@ -734,7 +734,7 @@ protected virtual LambdaExpression RewriteAndVisitLambda(LambdaExpression lambda /// Replaces the lambda's two parameters with type wrappers based on the given sources, and then visits /// the lambda's body. /// - protected virtual LambdaExpression RewriteAndVisitLambda( + private LambdaExpression RewriteAndVisitLambda( LambdaExpression lambda, EntityReferenceExpression source1, EntityReferenceExpression source2) @@ -758,7 +758,7 @@ protected virtual LambdaExpression RewriteAndVisitLambda( /// if possible. /// /// The rewritten entity equality expression, or null if rewriting could not occur for some reason. - protected virtual Expression RewriteEquality(bool equality, Expression left, Expression right) + private Expression RewriteEquality(bool equality, Expression left, Expression right) { // TODO: Consider throwing if a child has no flowed entity type, but has a Type that corresponds to an entity type on the model. // TODO: This would indicate an issue in our flowing logic, and would help the user (and us) understand what's going on. @@ -904,7 +904,7 @@ protected override Expression VisitExtension(Expression expression) } } - protected virtual Expression VisitNullConditional(NullConditionalExpression expression) + private Expression VisitNullConditional(NullConditionalExpression expression) { var newCaller = Visit(expression.Caller); var newAccessOperation = Visit(expression.AccessOperation); diff --git a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs index 70d1b8d804f..b2ef6a47848 100644 --- a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs @@ -23,6 +23,7 @@ public partial class NavigationExpandingExpressionVisitor : ExpressionVisitor private readonly EntityReferenceOptionalMarkingExpressionVisitor _entityReferenceOptionalMarkingExpressionVisitor; private readonly ISet _parameterNames = new HashSet(); private readonly EnumerableToQueryableMethodConvertingExpressionVisitor _enumerableToQueryableMethodConvertingExpressionVisitor; + private readonly EntityEqualityRewritingExpressionVisitor _entityEqualityRewritingExpressionVisitor; private readonly ParameterExtractingExpressionVisitor _parameterExtractingExpressionVisitor; private readonly Dictionary _parameterizedQueryFilterPredicateCache @@ -44,6 +45,7 @@ public NavigationExpandingExpressionVisitor( _reducingExpressionVisitor = new ReducingExpressionVisitor(); _entityReferenceOptionalMarkingExpressionVisitor = new EntityReferenceOptionalMarkingExpressionVisitor(); _enumerableToQueryableMethodConvertingExpressionVisitor = new EnumerableToQueryableMethodConvertingExpressionVisitor(); + _entityEqualityRewritingExpressionVisitor = new EntityEqualityRewritingExpressionVisitor(_queryCompilationContext); _parameterExtractingExpressionVisitor = new ParameterExtractingExpressionVisitor( evaluatableExpressionFilter, _parameters, @@ -162,7 +164,8 @@ private Expression ApplyQueryFilter(NavigationExpansionExpression navigationExpa { if (!_queryCompilationContext.IgnoreQueryFilters) { - var entityType = _queryCompilationContext.Model.FindEntityType(navigationExpansionExpression.Type.GetSequenceType()); + var sequenceType = navigationExpansionExpression.Type.GetSequenceType(); + var entityType = _queryCompilationContext.Model.FindEntityType(sequenceType); var rootEntityType = entityType.GetRootType(); var queryFilter = rootEntityType.GetQueryFilter(); if (queryFilter != null) @@ -172,13 +175,22 @@ private Expression ApplyQueryFilter(NavigationExpansionExpression navigationExpa filterPredicate = queryFilter; filterPredicate = (LambdaExpression)_parameterExtractingExpressionVisitor.ExtractParameters(filterPredicate); filterPredicate = (LambdaExpression)_enumerableToQueryableMethodConvertingExpressionVisitor.Visit(filterPredicate); + + // We need to do entity equality, but that requires a full method call on a query root to properly flow the + // entity information through. Construct a MethodCall wrapper for the predicate with the proper query root. + var filterWrapper = Expression.Call( + QueryableMethods.Where.MakeGenericMethod(rootEntityType.ClrType), + NullAsyncQueryProvider.Instance.CreateEntityQueryableExpression(rootEntityType.ClrType), + filterPredicate); + var rewrittenFilterWrapper = (MethodCallExpression)_entityEqualityRewritingExpressionVisitor.Rewrite(filterWrapper); + filterPredicate = rewrittenFilterWrapper.Arguments[1].UnwrapLambdaFromQuote(); + _parameterizedQueryFilterPredicateCache[rootEntityType] = filterPredicate; } filterPredicate = (LambdaExpression)new SelfReferenceEntityQueryableRewritingExpressionVisitor(this, entityType).Visit( filterPredicate); - var sequenceType = navigationExpansionExpression.Type.GetSequenceType(); // if we are constructing EntityQueryable of a derived type, we need to re-map filter predicate to the correct derived type var filterPredicateParameter = filterPredicate.Parameters[0]; diff --git a/test/EFCore.Specification.Tests/Query/FiltersTestBase.cs b/test/EFCore.Specification.Tests/Query/FiltersTestBase.cs index 9661b10f21a..ba77c1ae6f4 100644 --- a/test/EFCore.Specification.Tests/Query/FiltersTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/FiltersTestBase.cs @@ -168,6 +168,14 @@ public virtual void Compiled_query() } } + [ConditionalFact] + public virtual void Entity_Equality() + { + var results = _context.Orders.ToList(); + + Assert.Equal(80, results.Count); + } + protected NorthwindContext CreateContext() => Fixture.CreateContext(); public void Dispose() => _context.Dispose(); diff --git a/test/EFCore.Specification.Tests/TestModels/Northwind/NorthwindContext.cs b/test/EFCore.Specification.Tests/TestModels/Northwind/NorthwindContext.cs index 20416d43b48..8bd45f9cbd1 100644 --- a/test/EFCore.Specification.Tests/TestModels/Northwind/NorthwindContext.cs +++ b/test/EFCore.Specification.Tests/TestModels/Northwind/NorthwindContext.cs @@ -134,7 +134,7 @@ public void ConfigureFilters(ModelBuilder modelBuilder) // so we can capture TenantPrefix in filter exprs (simulates OnModelCreating). modelBuilder.Entity().HasQueryFilter(c => c.CompanyName.StartsWith(TenantPrefix)); - modelBuilder.Entity().HasQueryFilter(o => o.Customer.CompanyName != null); + modelBuilder.Entity().HasQueryFilter(o => o.Customer != null && o.Customer.CompanyName != null); modelBuilder.Entity().HasQueryFilter(od => EF.Property(od, "Quantity") > _quantity); modelBuilder.Entity().HasQueryFilter(e => e.Address.StartsWith("A")); modelBuilder.Entity().HasQueryFilter(p => ClientMethod(p)); diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/FiltersSqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/FiltersSqlServerTest.cs index 985bf869997..904d86049c3 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/FiltersSqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/FiltersSqlServerTest.cs @@ -125,7 +125,7 @@ LEFT JOIN ( FROM [Customers] AS [c0] WHERE ((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c0].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL)))) ) AS [t] ON [o].[CustomerID] = [t].[CustomerID] - WHERE [t].[CompanyName] IS NOT NULL + WHERE [t].[CustomerID] IS NOT NULL AND [t].[CompanyName] IS NOT NULL ) AS [t0] ON [c].[CustomerID] = [t0].[CustomerID] WHERE ((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL)))) ORDER BY [c].[CustomerID], [t0].[OrderID]"); @@ -156,7 +156,7 @@ LEFT JOIN ( FROM [Customers] AS [c] WHERE ((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL)))) ) AS [t] ON [o].[CustomerID] = [t].[CustomerID] -WHERE [t].[CompanyName] IS NOT NULL"); +WHERE [t].[CustomerID] IS NOT NULL AND [t].[CompanyName] IS NOT NULL"); } public override void Project_reference_that_itself_has_query_filter_with_another_reference() @@ -177,7 +177,7 @@ LEFT JOIN ( FROM [Customers] AS [c] WHERE ((@__ef_filter__TenantPrefix_1 = N'') AND @__ef_filter__TenantPrefix_1 IS NOT NULL) OR ([c].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_1 IS NOT NULL AND (((LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_1)) = @__ef_filter__TenantPrefix_1) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_1)) IS NOT NULL AND @__ef_filter__TenantPrefix_1 IS NOT NULL)) OR (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_1)) IS NULL AND @__ef_filter__TenantPrefix_1 IS NULL)))) ) AS [t] ON [o0].[CustomerID] = [t].[CustomerID] - WHERE [t].[CompanyName] IS NOT NULL + WHERE [t].[CustomerID] IS NOT NULL AND [t].[CompanyName] IS NOT NULL ) AS [t0] ON [o].[OrderID] = [t0].[OrderID] WHERE [o].[Quantity] > @__ef_filter___quantity_0"); } @@ -200,7 +200,7 @@ LEFT JOIN ( FROM [Customers] AS [c0] WHERE ((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c0].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL)))) ) AS [t] ON [o].[CustomerID] = [t].[CustomerID] - WHERE [t].[CompanyName] IS NOT NULL + WHERE [t].[CustomerID] IS NOT NULL AND [t].[CompanyName] IS NOT NULL ) AS [t0] ON [c].[CustomerID] = [t0].[CustomerID] INNER JOIN ( SELECT [o0].[OrderID], [o0].[ProductID], [o0].[Discount], [o0].[Quantity], [o0].[UnitPrice] @@ -252,7 +252,7 @@ LEFT JOIN ( FROM [Customers] AS [c] WHERE ((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL)))) ) AS [t] ON [o].[CustomerID] = [t].[CustomerID] -WHERE [t].[CompanyName] IS NOT NULL"); +WHERE [t].[CustomerID] IS NOT NULL AND [t].[CompanyName] IS NOT NULL"); } public override void Compiled_query() @@ -275,6 +275,23 @@ FROM [Customers] AS [c] WHERE (((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL))))) AND (([c].[CustomerID] = @__customerID) AND @__customerID IS NOT NULL)"); } + public override void Entity_Equality() + { + base.Entity_Equality(); + + AssertSql( + @"@__ef_filter__TenantPrefix_0='B' (Size = 4000) + +SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate] +FROM [Orders] AS [o] +LEFT JOIN ( + SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] + FROM [Customers] AS [c] + WHERE ((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL)))) +) AS [t] ON [o].[CustomerID] = [t].[CustomerID] +WHERE [t].[CustomerID] IS NOT NULL AND [t].[CompanyName] IS NOT NULL"); + } + private void AssertSql(params string[] expected) => Fixture.TestSqlLoggerFactory.AssertBaseline(expected); }