Skip to content

Commit

Permalink
Apply entity equality to query filters
Browse files Browse the repository at this point in the history
Fixes #18158
  • Loading branch information
roji committed Oct 18, 2019
2 parents addd8fa + 0e12de8 commit b563040
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
return methodCallExpression.Update(Unwrap(Visit(methodCallExpression.Object)), newArguments);
}

protected virtual Expression VisitContainsMethodCall(MethodCallExpression methodCallExpression)
private Expression VisitContainsMethodCall(MethodCallExpression methodCallExpression)
{
// We handle both Contains the extension method and the instance method
var (newSource, newItem) = methodCallExpression.Arguments.Count == 2
Expand Down Expand Up @@ -465,7 +465,7 @@ Expression NoTranslation() => methodCallExpression.Arguments.Count == 2
: methodCallExpression.Update(Unwrap(newSource), new[] { Unwrap(newItem) });
}

protected virtual Expression VisitOrderingMethodCall(MethodCallExpression methodCallExpression)
private Expression VisitOrderingMethodCall(MethodCallExpression methodCallExpression)
{
var arguments = methodCallExpression.Arguments;
var newSource = Visit(arguments[0]);
Expand Down Expand Up @@ -540,7 +540,7 @@ static MethodInfo GetOrderingMethodInfo(bool firstOrdering, bool ascending)
}
}

protected virtual Expression VisitSelectMethodCall(MethodCallExpression methodCallExpression)
private Expression VisitSelectMethodCall(MethodCallExpression methodCallExpression)
{
var arguments = methodCallExpression.Arguments;
var newSource = Visit(arguments[0]);
Expand Down Expand Up @@ -587,7 +587,7 @@ protected virtual Expression VisitSelectMethodCall(MethodCallExpression methodCa
throw new InvalidOperationException(CoreStrings.QueryFailed(methodCallExpression.Print(), GetType().Name));
}

protected virtual Expression VisitJoinMethodCall(MethodCallExpression methodCallExpression)
private Expression VisitJoinMethodCall(MethodCallExpression methodCallExpression)
{
var arguments = methodCallExpression.Arguments;

Expand Down Expand Up @@ -694,7 +694,7 @@ protected virtual Expression VisitJoinMethodCall(MethodCallExpression methodCall
: (Expression)newMethodCall;
}

protected virtual Expression VisitOfType(MethodCallExpression methodCallExpression)
private Expression VisitOfType(MethodCallExpression methodCallExpression)
{
var newSource = Visit(methodCallExpression.Arguments[0]);
var updatedMethodCall = methodCallExpression.Update(null, new[] { Unwrap(newSource) });
Expand All @@ -719,7 +719,7 @@ protected virtual Expression VisitOfType(MethodCallExpression methodCallExpressi
/// Replaces the lambda's single parameter with a type wrapper based on the given source, and then visits
/// the lambda's body.
/// </summary>
protected virtual LambdaExpression RewriteAndVisitLambda(LambdaExpression lambda, EntityReferenceExpression source)
private LambdaExpression RewriteAndVisitLambda(LambdaExpression lambda, EntityReferenceExpression source)
=> Expression.Lambda(
lambda.Type,
Visit(
Expand All @@ -734,7 +734,7 @@ protected virtual LambdaExpression RewriteAndVisitLambda(LambdaExpression lambda
/// Replaces the lambda's two parameters with type wrappers based on the given sources, and then visits
/// the lambda's body.
/// </summary>
protected virtual LambdaExpression RewriteAndVisitLambda(
private LambdaExpression RewriteAndVisitLambda(
LambdaExpression lambda,
EntityReferenceExpression source1,
EntityReferenceExpression source2)
Expand All @@ -758,7 +758,7 @@ protected virtual LambdaExpression RewriteAndVisitLambda(
/// if possible.
/// </summary>
/// <returns> The rewritten entity equality expression, or null if rewriting could not occur for some reason. </returns>
protected virtual Expression RewriteEquality(bool equality, Expression left, Expression right)
private Expression RewriteEquality(bool equality, Expression left, Expression right)
{
// TODO: Consider throwing if a child has no flowed entity type, but has a Type that corresponds to an entity type on the model.
// TODO: This would indicate an issue in our flowing logic, and would help the user (and us) understand what's going on.
Expand Down Expand Up @@ -904,7 +904,7 @@ protected override Expression VisitExtension(Expression expression)
}
}

protected virtual Expression VisitNullConditional(NullConditionalExpression expression)
private Expression VisitNullConditional(NullConditionalExpression expression)
{
var newCaller = Visit(expression.Caller);
var newAccessOperation = Visit(expression.AccessOperation);
Expand Down
16 changes: 14 additions & 2 deletions src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ public partial class NavigationExpandingExpressionVisitor : ExpressionVisitor
private readonly EntityReferenceOptionalMarkingExpressionVisitor _entityReferenceOptionalMarkingExpressionVisitor;
private readonly ISet<string> _parameterNames = new HashSet<string>();
private readonly EnumerableToQueryableMethodConvertingExpressionVisitor _enumerableToQueryableMethodConvertingExpressionVisitor;
private readonly EntityEqualityRewritingExpressionVisitor _entityEqualityRewritingExpressionVisitor;
private readonly ParameterExtractingExpressionVisitor _parameterExtractingExpressionVisitor;

private readonly Dictionary<IEntityType, LambdaExpression> _parameterizedQueryFilterPredicateCache
Expand All @@ -44,6 +45,7 @@ public NavigationExpandingExpressionVisitor(
_reducingExpressionVisitor = new ReducingExpressionVisitor();
_entityReferenceOptionalMarkingExpressionVisitor = new EntityReferenceOptionalMarkingExpressionVisitor();
_enumerableToQueryableMethodConvertingExpressionVisitor = new EnumerableToQueryableMethodConvertingExpressionVisitor();
_entityEqualityRewritingExpressionVisitor = new EntityEqualityRewritingExpressionVisitor(_queryCompilationContext);
_parameterExtractingExpressionVisitor = new ParameterExtractingExpressionVisitor(
evaluatableExpressionFilter,
_parameters,
Expand Down Expand Up @@ -162,7 +164,8 @@ private Expression ApplyQueryFilter(NavigationExpansionExpression navigationExpa
{
if (!_queryCompilationContext.IgnoreQueryFilters)
{
var entityType = _queryCompilationContext.Model.FindEntityType(navigationExpansionExpression.Type.GetSequenceType());
var sequenceType = navigationExpansionExpression.Type.GetSequenceType();
var entityType = _queryCompilationContext.Model.FindEntityType(sequenceType);
var rootEntityType = entityType.GetRootType();
var queryFilter = rootEntityType.GetQueryFilter();
if (queryFilter != null)
Expand All @@ -172,13 +175,22 @@ private Expression ApplyQueryFilter(NavigationExpansionExpression navigationExpa
filterPredicate = queryFilter;
filterPredicate = (LambdaExpression)_parameterExtractingExpressionVisitor.ExtractParameters(filterPredicate);
filterPredicate = (LambdaExpression)_enumerableToQueryableMethodConvertingExpressionVisitor.Visit(filterPredicate);

// We need to do entity equality, but that requires a full method call on a query root to properly flow the
// entity information through. Construct a MethodCall wrapper for the predicate with the proper query root.
var filterWrapper = Expression.Call(
QueryableMethods.Where.MakeGenericMethod(rootEntityType.ClrType),
NullAsyncQueryProvider.Instance.CreateEntityQueryableExpression(rootEntityType.ClrType),
filterPredicate);
var rewrittenFilterWrapper = (MethodCallExpression)_entityEqualityRewritingExpressionVisitor.Rewrite(filterWrapper);
filterPredicate = rewrittenFilterWrapper.Arguments[1].UnwrapLambdaFromQuote();

_parameterizedQueryFilterPredicateCache[rootEntityType] = filterPredicate;
}

filterPredicate =
(LambdaExpression)new SelfReferenceEntityQueryableRewritingExpressionVisitor(this, entityType).Visit(
filterPredicate);
var sequenceType = navigationExpansionExpression.Type.GetSequenceType();

// if we are constructing EntityQueryable of a derived type, we need to re-map filter predicate to the correct derived type
var filterPredicateParameter = filterPredicate.Parameters[0];
Expand Down
8 changes: 8 additions & 0 deletions test/EFCore.Specification.Tests/Query/FiltersTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,14 @@ public virtual void Compiled_query()
}
}

[ConditionalFact]
public virtual void Entity_Equality()
{
var results = _context.Orders.ToList();

Assert.Equal(80, results.Count);
}

protected NorthwindContext CreateContext() => Fixture.CreateContext();

public void Dispose() => _context.Dispose();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ public void ConfigureFilters(ModelBuilder modelBuilder)
// so we can capture TenantPrefix in filter exprs (simulates OnModelCreating).

modelBuilder.Entity<Customer>().HasQueryFilter(c => c.CompanyName.StartsWith(TenantPrefix));
modelBuilder.Entity<Order>().HasQueryFilter(o => o.Customer.CompanyName != null);
modelBuilder.Entity<Order>().HasQueryFilter(o => o.Customer != null && o.Customer.CompanyName != null);
modelBuilder.Entity<OrderDetail>().HasQueryFilter(od => EF.Property<short>(od, "Quantity") > _quantity);
modelBuilder.Entity<Employee>().HasQueryFilter(e => e.Address.StartsWith("A"));
modelBuilder.Entity<Product>().HasQueryFilter(p => ClientMethod(p));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ LEFT JOIN (
FROM [Customers] AS [c0]
WHERE ((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c0].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL))))
) AS [t] ON [o].[CustomerID] = [t].[CustomerID]
WHERE [t].[CompanyName] IS NOT NULL
WHERE [t].[CustomerID] IS NOT NULL AND [t].[CompanyName] IS NOT NULL
) AS [t0] ON [c].[CustomerID] = [t0].[CustomerID]
WHERE ((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL))))
ORDER BY [c].[CustomerID], [t0].[OrderID]");
Expand Down Expand Up @@ -156,7 +156,7 @@ LEFT JOIN (
FROM [Customers] AS [c]
WHERE ((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL))))
) AS [t] ON [o].[CustomerID] = [t].[CustomerID]
WHERE [t].[CompanyName] IS NOT NULL");
WHERE [t].[CustomerID] IS NOT NULL AND [t].[CompanyName] IS NOT NULL");
}

public override void Project_reference_that_itself_has_query_filter_with_another_reference()
Expand All @@ -177,7 +177,7 @@ LEFT JOIN (
FROM [Customers] AS [c]
WHERE ((@__ef_filter__TenantPrefix_1 = N'') AND @__ef_filter__TenantPrefix_1 IS NOT NULL) OR ([c].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_1 IS NOT NULL AND (((LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_1)) = @__ef_filter__TenantPrefix_1) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_1)) IS NOT NULL AND @__ef_filter__TenantPrefix_1 IS NOT NULL)) OR (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_1)) IS NULL AND @__ef_filter__TenantPrefix_1 IS NULL))))
) AS [t] ON [o0].[CustomerID] = [t].[CustomerID]
WHERE [t].[CompanyName] IS NOT NULL
WHERE [t].[CustomerID] IS NOT NULL AND [t].[CompanyName] IS NOT NULL
) AS [t0] ON [o].[OrderID] = [t0].[OrderID]
WHERE [o].[Quantity] > @__ef_filter___quantity_0");
}
Expand All @@ -200,7 +200,7 @@ LEFT JOIN (
FROM [Customers] AS [c0]
WHERE ((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c0].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL))))
) AS [t] ON [o].[CustomerID] = [t].[CustomerID]
WHERE [t].[CompanyName] IS NOT NULL
WHERE [t].[CustomerID] IS NOT NULL AND [t].[CompanyName] IS NOT NULL
) AS [t0] ON [c].[CustomerID] = [t0].[CustomerID]
INNER JOIN (
SELECT [o0].[OrderID], [o0].[ProductID], [o0].[Discount], [o0].[Quantity], [o0].[UnitPrice]
Expand Down Expand Up @@ -252,7 +252,7 @@ LEFT JOIN (
FROM [Customers] AS [c]
WHERE ((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL))))
) AS [t] ON [o].[CustomerID] = [t].[CustomerID]
WHERE [t].[CompanyName] IS NOT NULL");
WHERE [t].[CustomerID] IS NOT NULL AND [t].[CompanyName] IS NOT NULL");
}

public override void Compiled_query()
Expand All @@ -275,6 +275,23 @@ FROM [Customers] AS [c]
WHERE (((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL))))) AND (([c].[CustomerID] = @__customerID) AND @__customerID IS NOT NULL)");
}

public override void Entity_Equality()
{
base.Entity_Equality();

AssertSql(
@"@__ef_filter__TenantPrefix_0='B' (Size = 4000)
SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
FROM [Orders] AS [o]
LEFT JOIN (
SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
WHERE ((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL))))
) AS [t] ON [o].[CustomerID] = [t].[CustomerID]
WHERE [t].[CustomerID] IS NOT NULL AND [t].[CompanyName] IS NOT NULL");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
}
Expand Down

0 comments on commit b563040

Please sign in to comment.