diff --git a/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs b/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs
index 77280a56cba..3f8fe53cbc2 100644
--- a/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs
+++ b/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs
@@ -362,7 +362,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
return methodCallExpression.Update(Unwrap(Visit(methodCallExpression.Object)), newArguments);
}
- protected virtual Expression VisitContainsMethodCall(MethodCallExpression methodCallExpression)
+ private Expression VisitContainsMethodCall(MethodCallExpression methodCallExpression)
{
// We handle both Contains the extension method and the instance method
var (newSource, newItem) = methodCallExpression.Arguments.Count == 2
@@ -459,7 +459,7 @@ Expression NoTranslation() => methodCallExpression.Arguments.Count == 2
: methodCallExpression.Update(Unwrap(newSource), new[] { Unwrap(newItem) });
}
- protected virtual Expression VisitOrderingMethodCall(MethodCallExpression methodCallExpression)
+ private Expression VisitOrderingMethodCall(MethodCallExpression methodCallExpression)
{
var arguments = methodCallExpression.Arguments;
var newSource = Visit(arguments[0]);
@@ -534,7 +534,7 @@ static MethodInfo GetOrderingMethodInfo(bool firstOrdering, bool ascending)
}
}
- protected virtual Expression VisitSelectMethodCall(MethodCallExpression methodCallExpression)
+ private Expression VisitSelectMethodCall(MethodCallExpression methodCallExpression)
{
var arguments = methodCallExpression.Arguments;
var newSource = Visit(arguments[0]);
@@ -581,7 +581,7 @@ protected virtual Expression VisitSelectMethodCall(MethodCallExpression methodCa
throw new InvalidOperationException(CoreStrings.QueryFailed(methodCallExpression.Print(), GetType().Name));
}
- protected virtual Expression VisitJoinMethodCall(MethodCallExpression methodCallExpression)
+ private Expression VisitJoinMethodCall(MethodCallExpression methodCallExpression)
{
var arguments = methodCallExpression.Arguments;
@@ -688,7 +688,7 @@ protected virtual Expression VisitJoinMethodCall(MethodCallExpression methodCall
: (Expression)newMethodCall;
}
- protected virtual Expression VisitOfType(MethodCallExpression methodCallExpression)
+ private Expression VisitOfType(MethodCallExpression methodCallExpression)
{
var newSource = Visit(methodCallExpression.Arguments[0]);
var updatedMethodCall = methodCallExpression.Update(null, new[] { Unwrap(newSource) });
@@ -713,7 +713,7 @@ protected virtual Expression VisitOfType(MethodCallExpression methodCallExpressi
/// Replaces the lambda's single parameter with a type wrapper based on the given source, and then visits
/// the lambda's body.
///
- protected virtual LambdaExpression RewriteAndVisitLambda(LambdaExpression lambda, EntityReferenceExpression source)
+ private LambdaExpression RewriteAndVisitLambda(LambdaExpression lambda, EntityReferenceExpression source)
=> Expression.Lambda(
lambda.Type,
Visit(
@@ -728,7 +728,7 @@ protected virtual LambdaExpression RewriteAndVisitLambda(LambdaExpression lambda
/// Replaces the lambda's two parameters with type wrappers based on the given sources, and then visits
/// the lambda's body.
///
- protected virtual LambdaExpression RewriteAndVisitLambda(
+ private LambdaExpression RewriteAndVisitLambda(
LambdaExpression lambda,
EntityReferenceExpression source1,
EntityReferenceExpression source2)
@@ -752,7 +752,7 @@ protected virtual LambdaExpression RewriteAndVisitLambda(
/// if possible.
///
/// The rewritten entity equality expression, or null if rewriting could not occur for some reason.
- protected virtual Expression RewriteEquality(bool equality, Expression left, Expression right)
+ private Expression RewriteEquality(bool equality, Expression left, Expression right)
{
// TODO: Consider throwing if a child has no flowed entity type, but has a Type that corresponds to an entity type on the model.
// TODO: This would indicate an issue in our flowing logic, and would help the user (and us) understand what's going on.
@@ -898,7 +898,7 @@ protected override Expression VisitExtension(Expression expression)
}
}
- protected virtual Expression VisitNullConditional(NullConditionalExpression expression)
+ private Expression VisitNullConditional(NullConditionalExpression expression)
{
var newCaller = Visit(expression.Caller);
var newAccessOperation = Visit(expression.AccessOperation);
diff --git a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs
index 78923b7b1d6..05bf3882a61 100644
--- a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs
+++ b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs
@@ -23,6 +23,7 @@ public partial class NavigationExpandingExpressionVisitor : ExpressionVisitor
private readonly EntityReferenceOptionalMarkingExpressionVisitor _entityReferenceOptionalMarkingExpressionVisitor;
private readonly ISet _parameterNames = new HashSet();
private readonly EnumerableToQueryableMethodConvertingExpressionVisitor _enumerableToQueryableMethodConvertingExpressionVisitor;
+ private readonly EntityEqualityRewritingExpressionVisitor _entityEqualityRewritingExpressionVisitor;
private readonly ParameterExtractingExpressionVisitor _parameterExtractingExpressionVisitor;
private readonly Dictionary _parameterizedQueryFilterPredicateCache
@@ -44,6 +45,7 @@ public NavigationExpandingExpressionVisitor(
_reducingExpressionVisitor = new ReducingExpressionVisitor();
_entityReferenceOptionalMarkingExpressionVisitor = new EntityReferenceOptionalMarkingExpressionVisitor();
_enumerableToQueryableMethodConvertingExpressionVisitor = new EnumerableToQueryableMethodConvertingExpressionVisitor();
+ _entityEqualityRewritingExpressionVisitor = new EntityEqualityRewritingExpressionVisitor(_queryCompilationContext);
_parameterExtractingExpressionVisitor = new ParameterExtractingExpressionVisitor(
evaluatableExpressionFilter,
_parameters,
@@ -162,7 +164,8 @@ private Expression ApplyQueryFilter(NavigationExpansionExpression navigationExpa
{
if (!_queryCompilationContext.IgnoreQueryFilters)
{
- var entityType = _queryCompilationContext.Model.FindEntityType(navigationExpansionExpression.Type.GetSequenceType());
+ var sequenceType = navigationExpansionExpression.Type.GetSequenceType();
+ var entityType = _queryCompilationContext.Model.FindEntityType(sequenceType);
var rootEntityType = entityType.GetRootType();
var queryFilter = rootEntityType.GetQueryFilter();
if (queryFilter != null)
@@ -172,13 +175,22 @@ private Expression ApplyQueryFilter(NavigationExpansionExpression navigationExpa
filterPredicate = queryFilter;
filterPredicate = (LambdaExpression)_parameterExtractingExpressionVisitor.ExtractParameters(filterPredicate);
filterPredicate = (LambdaExpression)_enumerableToQueryableMethodConvertingExpressionVisitor.Visit(filterPredicate);
+
+ // We need to do entity equality, but that requires a full method call on a query root to properly flow the
+ // entity information through. Construct a MethodCall wrapper for the predicate with the proper query root.
+ var filterWrapper = Expression.Call(
+ QueryableMethods.Where.MakeGenericMethod(rootEntityType.ClrType),
+ NullAsyncQueryProvider.Instance.CreateEntityQueryableExpression(rootEntityType.ClrType),
+ filterPredicate);
+ var rewrittenFilterWrapper = (MethodCallExpression)_entityEqualityRewritingExpressionVisitor.Rewrite(filterWrapper);
+ filterPredicate = rewrittenFilterWrapper.Arguments[1].UnwrapLambdaFromQuote();
+
_parameterizedQueryFilterPredicateCache[rootEntityType] = filterPredicate;
}
filterPredicate =
(LambdaExpression)new SelfReferenceEntityQueryableRewritingExpressionVisitor(this, entityType).Visit(
filterPredicate);
- var sequenceType = navigationExpansionExpression.Type.GetSequenceType();
// if we are constructing EntityQueryable of a derived type, we need to re-map filter predicate to the correct derived type
var filterPredicateParameter = filterPredicate.Parameters[0];
diff --git a/test/EFCore.Specification.Tests/Query/FiltersTestBase.cs b/test/EFCore.Specification.Tests/Query/FiltersTestBase.cs
index f143e58991b..f4c4dcd6578 100644
--- a/test/EFCore.Specification.Tests/Query/FiltersTestBase.cs
+++ b/test/EFCore.Specification.Tests/Query/FiltersTestBase.cs
@@ -168,6 +168,14 @@ public virtual void Compiled_query()
}
}
+ [ConditionalFact]
+ public virtual void Entity_Equality()
+ {
+ var results = _context.Orders.ToList();
+
+ Assert.Equal(80, results.Count);
+ }
+
protected NorthwindContext CreateContext() => Fixture.CreateContext();
public void Dispose() => _context.Dispose();
diff --git a/test/EFCore.Specification.Tests/TestModels/Northwind/NorthwindContext.cs b/test/EFCore.Specification.Tests/TestModels/Northwind/NorthwindContext.cs
index 20416d43b48..8bd45f9cbd1 100644
--- a/test/EFCore.Specification.Tests/TestModels/Northwind/NorthwindContext.cs
+++ b/test/EFCore.Specification.Tests/TestModels/Northwind/NorthwindContext.cs
@@ -134,7 +134,7 @@ public void ConfigureFilters(ModelBuilder modelBuilder)
// so we can capture TenantPrefix in filter exprs (simulates OnModelCreating).
modelBuilder.Entity().HasQueryFilter(c => c.CompanyName.StartsWith(TenantPrefix));
- modelBuilder.Entity().HasQueryFilter(o => o.Customer.CompanyName != null);
+ modelBuilder.Entity().HasQueryFilter(o => o.Customer != null && o.Customer.CompanyName != null);
modelBuilder.Entity().HasQueryFilter(od => EF.Property(od, "Quantity") > _quantity);
modelBuilder.Entity().HasQueryFilter(e => e.Address.StartsWith("A"));
modelBuilder.Entity().HasQueryFilter(p => ClientMethod(p));
diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/FiltersSqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/FiltersSqlServerTest.cs
index 985bf869997..904d86049c3 100644
--- a/test/EFCore.SqlServer.FunctionalTests/Query/FiltersSqlServerTest.cs
+++ b/test/EFCore.SqlServer.FunctionalTests/Query/FiltersSqlServerTest.cs
@@ -125,7 +125,7 @@ LEFT JOIN (
FROM [Customers] AS [c0]
WHERE ((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c0].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL))))
) AS [t] ON [o].[CustomerID] = [t].[CustomerID]
- WHERE [t].[CompanyName] IS NOT NULL
+ WHERE [t].[CustomerID] IS NOT NULL AND [t].[CompanyName] IS NOT NULL
) AS [t0] ON [c].[CustomerID] = [t0].[CustomerID]
WHERE ((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL))))
ORDER BY [c].[CustomerID], [t0].[OrderID]");
@@ -156,7 +156,7 @@ LEFT JOIN (
FROM [Customers] AS [c]
WHERE ((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL))))
) AS [t] ON [o].[CustomerID] = [t].[CustomerID]
-WHERE [t].[CompanyName] IS NOT NULL");
+WHERE [t].[CustomerID] IS NOT NULL AND [t].[CompanyName] IS NOT NULL");
}
public override void Project_reference_that_itself_has_query_filter_with_another_reference()
@@ -177,7 +177,7 @@ LEFT JOIN (
FROM [Customers] AS [c]
WHERE ((@__ef_filter__TenantPrefix_1 = N'') AND @__ef_filter__TenantPrefix_1 IS NOT NULL) OR ([c].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_1 IS NOT NULL AND (((LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_1)) = @__ef_filter__TenantPrefix_1) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_1)) IS NOT NULL AND @__ef_filter__TenantPrefix_1 IS NOT NULL)) OR (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_1)) IS NULL AND @__ef_filter__TenantPrefix_1 IS NULL))))
) AS [t] ON [o0].[CustomerID] = [t].[CustomerID]
- WHERE [t].[CompanyName] IS NOT NULL
+ WHERE [t].[CustomerID] IS NOT NULL AND [t].[CompanyName] IS NOT NULL
) AS [t0] ON [o].[OrderID] = [t0].[OrderID]
WHERE [o].[Quantity] > @__ef_filter___quantity_0");
}
@@ -200,7 +200,7 @@ LEFT JOIN (
FROM [Customers] AS [c0]
WHERE ((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c0].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL))))
) AS [t] ON [o].[CustomerID] = [t].[CustomerID]
- WHERE [t].[CompanyName] IS NOT NULL
+ WHERE [t].[CustomerID] IS NOT NULL AND [t].[CompanyName] IS NOT NULL
) AS [t0] ON [c].[CustomerID] = [t0].[CustomerID]
INNER JOIN (
SELECT [o0].[OrderID], [o0].[ProductID], [o0].[Discount], [o0].[Quantity], [o0].[UnitPrice]
@@ -252,7 +252,7 @@ LEFT JOIN (
FROM [Customers] AS [c]
WHERE ((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL))))
) AS [t] ON [o].[CustomerID] = [t].[CustomerID]
-WHERE [t].[CompanyName] IS NOT NULL");
+WHERE [t].[CustomerID] IS NOT NULL AND [t].[CompanyName] IS NOT NULL");
}
public override void Compiled_query()
@@ -275,6 +275,23 @@ FROM [Customers] AS [c]
WHERE (((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL))))) AND (([c].[CustomerID] = @__customerID) AND @__customerID IS NOT NULL)");
}
+ public override void Entity_Equality()
+ {
+ base.Entity_Equality();
+
+ AssertSql(
+ @"@__ef_filter__TenantPrefix_0='B' (Size = 4000)
+
+SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
+FROM [Orders] AS [o]
+LEFT JOIN (
+ SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
+ FROM [Customers] AS [c]
+ WHERE ((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL))))
+) AS [t] ON [o].[CustomerID] = [t].[CustomerID]
+WHERE [t].[CustomerID] IS NOT NULL AND [t].[CompanyName] IS NOT NULL");
+ }
+
private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
}