Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply entity equality to query filters #18426

Merged
merged 2 commits into from
Oct 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
return methodCallExpression.Update(Unwrap(Visit(methodCallExpression.Object)), newArguments);
}

protected virtual Expression VisitContainsMethodCall(MethodCallExpression methodCallExpression)
private Expression VisitContainsMethodCall(MethodCallExpression methodCallExpression)
{
// We handle both Contains the extension method and the instance method
var (newSource, newItem) = methodCallExpression.Arguments.Count == 2
Expand Down Expand Up @@ -459,7 +459,7 @@ Expression NoTranslation() => methodCallExpression.Arguments.Count == 2
: methodCallExpression.Update(Unwrap(newSource), new[] { Unwrap(newItem) });
}

protected virtual Expression VisitOrderingMethodCall(MethodCallExpression methodCallExpression)
private Expression VisitOrderingMethodCall(MethodCallExpression methodCallExpression)
{
var arguments = methodCallExpression.Arguments;
var newSource = Visit(arguments[0]);
Expand Down Expand Up @@ -534,7 +534,7 @@ static MethodInfo GetOrderingMethodInfo(bool firstOrdering, bool ascending)
}
}

protected virtual Expression VisitSelectMethodCall(MethodCallExpression methodCallExpression)
private Expression VisitSelectMethodCall(MethodCallExpression methodCallExpression)
{
var arguments = methodCallExpression.Arguments;
var newSource = Visit(arguments[0]);
Expand Down Expand Up @@ -581,7 +581,7 @@ protected virtual Expression VisitSelectMethodCall(MethodCallExpression methodCa
throw new InvalidOperationException(CoreStrings.QueryFailed(methodCallExpression.Print(), GetType().Name));
}

protected virtual Expression VisitJoinMethodCall(MethodCallExpression methodCallExpression)
private Expression VisitJoinMethodCall(MethodCallExpression methodCallExpression)
{
var arguments = methodCallExpression.Arguments;

Expand Down Expand Up @@ -688,7 +688,7 @@ protected virtual Expression VisitJoinMethodCall(MethodCallExpression methodCall
: (Expression)newMethodCall;
}

protected virtual Expression VisitOfType(MethodCallExpression methodCallExpression)
private Expression VisitOfType(MethodCallExpression methodCallExpression)
{
var newSource = Visit(methodCallExpression.Arguments[0]);
var updatedMethodCall = methodCallExpression.Update(null, new[] { Unwrap(newSource) });
Expand All @@ -713,7 +713,7 @@ protected virtual Expression VisitOfType(MethodCallExpression methodCallExpressi
/// Replaces the lambda's single parameter with a type wrapper based on the given source, and then visits
/// the lambda's body.
/// </summary>
protected virtual LambdaExpression RewriteAndVisitLambda(LambdaExpression lambda, EntityReferenceExpression source)
private LambdaExpression RewriteAndVisitLambda(LambdaExpression lambda, EntityReferenceExpression source)
=> Expression.Lambda(
lambda.Type,
Visit(
Expand All @@ -728,7 +728,7 @@ protected virtual LambdaExpression RewriteAndVisitLambda(LambdaExpression lambda
/// Replaces the lambda's two parameters with type wrappers based on the given sources, and then visits
/// the lambda's body.
/// </summary>
protected virtual LambdaExpression RewriteAndVisitLambda(
private LambdaExpression RewriteAndVisitLambda(
LambdaExpression lambda,
EntityReferenceExpression source1,
EntityReferenceExpression source2)
Expand All @@ -752,7 +752,7 @@ protected virtual LambdaExpression RewriteAndVisitLambda(
/// if possible.
/// </summary>
/// <returns> The rewritten entity equality expression, or null if rewriting could not occur for some reason. </returns>
protected virtual Expression RewriteEquality(bool equality, Expression left, Expression right)
private Expression RewriteEquality(bool equality, Expression left, Expression right)
{
// TODO: Consider throwing if a child has no flowed entity type, but has a Type that corresponds to an entity type on the model.
// TODO: This would indicate an issue in our flowing logic, and would help the user (and us) understand what's going on.
Expand Down Expand Up @@ -898,7 +898,7 @@ protected override Expression VisitExtension(Expression expression)
}
}

protected virtual Expression VisitNullConditional(NullConditionalExpression expression)
private Expression VisitNullConditional(NullConditionalExpression expression)
{
var newCaller = Visit(expression.Caller);
var newAccessOperation = Visit(expression.AccessOperation);
Expand Down
16 changes: 14 additions & 2 deletions src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ public partial class NavigationExpandingExpressionVisitor : ExpressionVisitor
private readonly EntityReferenceOptionalMarkingExpressionVisitor _entityReferenceOptionalMarkingExpressionVisitor;
private readonly ISet<string> _parameterNames = new HashSet<string>();
private readonly EnumerableToQueryableMethodConvertingExpressionVisitor _enumerableToQueryableMethodConvertingExpressionVisitor;
private readonly EntityEqualityRewritingExpressionVisitor _entityEqualityRewritingExpressionVisitor;
private readonly ParameterExtractingExpressionVisitor _parameterExtractingExpressionVisitor;

private readonly Dictionary<IEntityType, LambdaExpression> _parameterizedQueryFilterPredicateCache
Expand All @@ -44,6 +45,7 @@ public NavigationExpandingExpressionVisitor(
_reducingExpressionVisitor = new ReducingExpressionVisitor();
_entityReferenceOptionalMarkingExpressionVisitor = new EntityReferenceOptionalMarkingExpressionVisitor();
_enumerableToQueryableMethodConvertingExpressionVisitor = new EnumerableToQueryableMethodConvertingExpressionVisitor();
_entityEqualityRewritingExpressionVisitor = new EntityEqualityRewritingExpressionVisitor(_queryCompilationContext);
_parameterExtractingExpressionVisitor = new ParameterExtractingExpressionVisitor(
evaluatableExpressionFilter,
_parameters,
Expand Down Expand Up @@ -162,7 +164,8 @@ private Expression ApplyQueryFilter(NavigationExpansionExpression navigationExpa
{
if (!_queryCompilationContext.IgnoreQueryFilters)
{
var entityType = _queryCompilationContext.Model.FindEntityType(navigationExpansionExpression.Type.GetSequenceType());
var sequenceType = navigationExpansionExpression.Type.GetSequenceType();
var entityType = _queryCompilationContext.Model.FindEntityType(sequenceType);
var rootEntityType = entityType.GetRootType();
var queryFilter = rootEntityType.GetQueryFilter();
if (queryFilter != null)
Expand All @@ -172,13 +175,22 @@ private Expression ApplyQueryFilter(NavigationExpansionExpression navigationExpa
filterPredicate = queryFilter;
filterPredicate = (LambdaExpression)_parameterExtractingExpressionVisitor.ExtractParameters(filterPredicate);
filterPredicate = (LambdaExpression)_enumerableToQueryableMethodConvertingExpressionVisitor.Visit(filterPredicate);

// We need to do entity equality, but that requires a full method call on a query root to properly flow the
// entity information through. Construct a MethodCall wrapper for the predicate with the proper query root.
var filterWrapper = Expression.Call(
QueryableMethods.Where.MakeGenericMethod(rootEntityType.ClrType),
NullAsyncQueryProvider.Instance.CreateEntityQueryableExpression(rootEntityType.ClrType),
filterPredicate);
var rewrittenFilterWrapper = (MethodCallExpression)_entityEqualityRewritingExpressionVisitor.Rewrite(filterWrapper);
filterPredicate = rewrittenFilterWrapper.Arguments[1].UnwrapLambdaFromQuote();

_parameterizedQueryFilterPredicateCache[rootEntityType] = filterPredicate;
}

filterPredicate =
(LambdaExpression)new SelfReferenceEntityQueryableRewritingExpressionVisitor(this, entityType).Visit(
filterPredicate);
var sequenceType = navigationExpansionExpression.Type.GetSequenceType();

// if we are constructing EntityQueryable of a derived type, we need to re-map filter predicate to the correct derived type
var filterPredicateParameter = filterPredicate.Parameters[0];
Expand Down
8 changes: 8 additions & 0 deletions test/EFCore.Specification.Tests/Query/FiltersTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,14 @@ public virtual void Compiled_query()
}
}

[ConditionalFact]
public virtual void Entity_Equality()
{
var results = _context.Orders.ToList();

Assert.Equal(80, results.Count);
}

protected NorthwindContext CreateContext() => Fixture.CreateContext();

public void Dispose() => _context.Dispose();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ public void ConfigureFilters(ModelBuilder modelBuilder)
// so we can capture TenantPrefix in filter exprs (simulates OnModelCreating).

modelBuilder.Entity<Customer>().HasQueryFilter(c => c.CompanyName.StartsWith(TenantPrefix));
modelBuilder.Entity<Order>().HasQueryFilter(o => o.Customer.CompanyName != null);
modelBuilder.Entity<Order>().HasQueryFilter(o => o.Customer != null && o.Customer.CompanyName != null);
modelBuilder.Entity<OrderDetail>().HasQueryFilter(od => EF.Property<short>(od, "Quantity") > _quantity);
modelBuilder.Entity<Employee>().HasQueryFilter(e => e.Address.StartsWith("A"));
modelBuilder.Entity<Product>().HasQueryFilter(p => ClientMethod(p));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ LEFT JOIN (
FROM [Customers] AS [c0]
WHERE ((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c0].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL))))
) AS [t] ON [o].[CustomerID] = [t].[CustomerID]
WHERE [t].[CompanyName] IS NOT NULL
WHERE [t].[CustomerID] IS NOT NULL AND [t].[CompanyName] IS NOT NULL
) AS [t0] ON [c].[CustomerID] = [t0].[CustomerID]
WHERE ((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL))))
ORDER BY [c].[CustomerID], [t0].[OrderID]");
Expand Down Expand Up @@ -156,7 +156,7 @@ LEFT JOIN (
FROM [Customers] AS [c]
WHERE ((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL))))
) AS [t] ON [o].[CustomerID] = [t].[CustomerID]
WHERE [t].[CompanyName] IS NOT NULL");
WHERE [t].[CustomerID] IS NOT NULL AND [t].[CompanyName] IS NOT NULL");
}

public override void Project_reference_that_itself_has_query_filter_with_another_reference()
Expand All @@ -177,7 +177,7 @@ LEFT JOIN (
FROM [Customers] AS [c]
WHERE ((@__ef_filter__TenantPrefix_1 = N'') AND @__ef_filter__TenantPrefix_1 IS NOT NULL) OR ([c].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_1 IS NOT NULL AND (((LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_1)) = @__ef_filter__TenantPrefix_1) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_1)) IS NOT NULL AND @__ef_filter__TenantPrefix_1 IS NOT NULL)) OR (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_1)) IS NULL AND @__ef_filter__TenantPrefix_1 IS NULL))))
) AS [t] ON [o0].[CustomerID] = [t].[CustomerID]
WHERE [t].[CompanyName] IS NOT NULL
WHERE [t].[CustomerID] IS NOT NULL AND [t].[CompanyName] IS NOT NULL
) AS [t0] ON [o].[OrderID] = [t0].[OrderID]
WHERE [o].[Quantity] > @__ef_filter___quantity_0");
}
Expand All @@ -200,7 +200,7 @@ LEFT JOIN (
FROM [Customers] AS [c0]
WHERE ((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c0].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL))))
) AS [t] ON [o].[CustomerID] = [t].[CustomerID]
WHERE [t].[CompanyName] IS NOT NULL
WHERE [t].[CustomerID] IS NOT NULL AND [t].[CompanyName] IS NOT NULL
) AS [t0] ON [c].[CustomerID] = [t0].[CustomerID]
INNER JOIN (
SELECT [o0].[OrderID], [o0].[ProductID], [o0].[Discount], [o0].[Quantity], [o0].[UnitPrice]
Expand Down Expand Up @@ -252,7 +252,7 @@ LEFT JOIN (
FROM [Customers] AS [c]
WHERE ((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL))))
) AS [t] ON [o].[CustomerID] = [t].[CustomerID]
WHERE [t].[CompanyName] IS NOT NULL");
WHERE [t].[CustomerID] IS NOT NULL AND [t].[CompanyName] IS NOT NULL");
}

public override void Compiled_query()
Expand All @@ -275,6 +275,23 @@ FROM [Customers] AS [c]
WHERE (((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL))))) AND (([c].[CustomerID] = @__customerID) AND @__customerID IS NOT NULL)");
}

public override void Entity_Equality()
{
base.Entity_Equality();

AssertSql(
@"@__ef_filter__TenantPrefix_0='B' (Size = 4000)

SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
FROM [Orders] AS [o]
LEFT JOIN (
SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
WHERE ((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL))))
) AS [t] ON [o].[CustomerID] = [t].[CustomerID]
WHERE [t].[CustomerID] IS NOT NULL AND [t].[CompanyName] IS NOT NULL");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
}
Expand Down