Skip to content

Commit

Permalink
Apply entity equality to query filters
Browse files Browse the repository at this point in the history
Fixes #18158
  • Loading branch information
roji committed Oct 17, 2019
1 parent 37de58b commit aadfe52
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
{
// Visit all arguments, rewriting the single lambda to replace its parameter expression
newArguments[i] = arguments[i].GetLambdaOrNull() is LambdaExpression lambda
? Unwrap(RewriteAndVisitLambda(lambda, newSourceWrapper))
? Unwrap(RewriteSingleParamLambda(lambda, newSourceWrapper))
: Unwrap(Visit(arguments[i]));
}

Expand Down Expand Up @@ -469,7 +469,7 @@ protected virtual Expression VisitOrderingMethodCall(MethodCallExpression method
return methodCallExpression.Update(null, new[] { newSource, Unwrap(Visit(arguments[1])) });
}

var newKeySelector = RewriteAndVisitLambda(arguments[1].UnwrapLambdaFromQuote(), sourceWrapper);
var newKeySelector = RewriteSingleParamLambda(arguments[1].UnwrapLambdaFromQuote(), sourceWrapper);

if (!(newKeySelector.Body is EntityReferenceExpression keySelectorWrapper)
|| !(keySelectorWrapper.EntityType is IEntityType entityType))
Expand Down Expand Up @@ -553,7 +553,7 @@ protected virtual Expression VisitSelectMethodCall(MethodCallExpression methodCa
if (arguments.Count == 2)
{
var selector = arguments[1].UnwrapLambdaFromQuote();
var newSelector = RewriteAndVisitLambda(selector, sourceWrapper);
var newSelector = RewriteSingleParamLambda(selector, sourceWrapper);

newMethodCall = methodCallExpression.Update(null, new[] { Unwrap(newSource), Unwrap(newSelector) });
return newSelector.Body is EntityReferenceExpression entityWrapper
Expand All @@ -564,11 +564,11 @@ protected virtual Expression VisitSelectMethodCall(MethodCallExpression methodCa
if (arguments.Count == 3)
{
var collectionSelector = arguments[1].UnwrapLambdaFromQuote();
var newCollectionSelector = RewriteAndVisitLambda(collectionSelector, sourceWrapper);
var newCollectionSelector = RewriteSingleParamLambda(collectionSelector, sourceWrapper);

var resultSelector = arguments[2].UnwrapLambdaFromQuote();
var newResultSelector = newCollectionSelector.Body is EntityReferenceExpression newCollectionSelectorWrapper
? RewriteAndVisitLambda(resultSelector, sourceWrapper, newCollectionSelectorWrapper)
? RewriteTwoParamLambda(resultSelector, sourceWrapper, newCollectionSelectorWrapper)
: (LambdaExpression)Visit(resultSelector);

newMethodCall = methodCallExpression.Update(
Expand Down Expand Up @@ -610,9 +610,9 @@ protected virtual Expression VisitJoinMethodCall(MethodCallExpression methodCall
});
}

var newOuterKeySelector = RewriteAndVisitLambda(outerKeySelector, outerWrapper);
var newInnerKeySelector = RewriteAndVisitLambda(innerKeySelector, innerWrapper);
var newResultSelector = RewriteAndVisitLambda(resultSelector, outerWrapper, innerWrapper);
var newOuterKeySelector = RewriteSingleParamLambda(outerKeySelector, outerWrapper);
var newInnerKeySelector = RewriteSingleParamLambda(innerKeySelector, innerWrapper);
var newResultSelector = RewriteTwoParamLambda(resultSelector, outerWrapper, innerWrapper);

MethodCallExpression newMethodCall;

Expand Down Expand Up @@ -709,12 +709,21 @@ protected virtual Expression VisitOfType(MethodCallExpression methodCallExpressi
return new EntityReferenceExpression(updatedMethodCall, castEntityType);
}

public virtual LambdaExpression RewriteSingleParamLambda(LambdaExpression lambda, IEntityType entityType)
=> RewriteSingleParamLambda(lambda, new EntityReferenceExpression(lambda.Parameters[0], entityType));

/// <summary>
/// Replaces the lambda's single parameter with a type wrapper based on the given source, and then visits
/// the lambda's body.
/// </summary>
protected virtual LambdaExpression RewriteAndVisitLambda(LambdaExpression lambda, EntityReferenceExpression source)
=> Expression.Lambda(
protected virtual LambdaExpression RewriteSingleParamLambda(LambdaExpression lambda, EntityReferenceExpression source)
{
if (lambda.Parameters.Count != 1)
{
throw new ArgumentException("Lambda must have exactly one parameter", nameof(lambda));
}

return Expression.Lambda(
lambda.Type,
Visit(
ReplacingExpressionVisitor.Replace(
Expand All @@ -723,16 +732,22 @@ protected virtual LambdaExpression RewriteAndVisitLambda(LambdaExpression lambda
lambda.Body)),
lambda.TailCall,
lambda.Parameters);
}

/// <summary>
/// Replaces the lambda's two parameters with type wrappers based on the given sources, and then visits
/// the lambda's body.
/// </summary>
protected virtual LambdaExpression RewriteAndVisitLambda(
protected virtual LambdaExpression RewriteTwoParamLambda(
LambdaExpression lambda,
EntityReferenceExpression source1,
EntityReferenceExpression source2)
{
if (lambda.Parameters.Count != 2)
{
throw new ArgumentException("Lambda must have exactly two parameters", nameof(lambda));
}

Expression original1 = lambda.Parameters[0];
Expression replacement1 = source1.Update(lambda.Parameters[0]);
Expression original2 = lambda.Parameters[1];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ public partial class NavigationExpandingExpressionVisitor : ExpressionVisitor
private readonly EntityReferenceOptionalMarkingExpressionVisitor _entityReferenceOptionalMarkingExpressionVisitor;
private readonly ISet<string> _parameterNames = new HashSet<string>();
private readonly EnumerableToQueryableMethodConvertingExpressionVisitor _enumerableToQueryableMethodConvertingExpressionVisitor;
private readonly EntityEqualityRewritingExpressionVisitor _entityEqualityRewritingExpressionVisitor;
private readonly ParameterExtractingExpressionVisitor _parameterExtractingExpressionVisitor;

private readonly Dictionary<IEntityType, LambdaExpression> _parameterizedQueryFilterPredicateCache
Expand All @@ -44,6 +45,7 @@ public NavigationExpandingExpressionVisitor(
_reducingExpressionVisitor = new ReducingExpressionVisitor();
_entityReferenceOptionalMarkingExpressionVisitor = new EntityReferenceOptionalMarkingExpressionVisitor();
_enumerableToQueryableMethodConvertingExpressionVisitor = new EnumerableToQueryableMethodConvertingExpressionVisitor();
_entityEqualityRewritingExpressionVisitor = new EntityEqualityRewritingExpressionVisitor(_queryCompilationContext);
_parameterExtractingExpressionVisitor = new ParameterExtractingExpressionVisitor(
evaluatableExpressionFilter,
_parameters,
Expand Down Expand Up @@ -172,6 +174,7 @@ private Expression ApplyQueryFilter(NavigationExpansionExpression navigationExpa
filterPredicate = queryFilter;
filterPredicate = (LambdaExpression)_parameterExtractingExpressionVisitor.ExtractParameters(filterPredicate);
filterPredicate = (LambdaExpression)_enumerableToQueryableMethodConvertingExpressionVisitor.Visit(filterPredicate);
filterPredicate = _entityEqualityRewritingExpressionVisitor.RewriteSingleParamLambda(filterPredicate, rootEntityType);
_parameterizedQueryFilterPredicateCache[rootEntityType] = filterPredicate;
}

Expand Down
8 changes: 8 additions & 0 deletions test/EFCore.Specification.Tests/Query/FiltersTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,14 @@ public virtual void Compiled_query()
}
}

[ConditionalFact]
public virtual void Entity_Equality()
{
var results = _context.Orders.ToList();

Assert.Equal(80, results.Count);
}

protected NorthwindContext CreateContext() => Fixture.CreateContext();

public void Dispose() => _context.Dispose();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ public void ConfigureFilters(ModelBuilder modelBuilder)
// so we can capture TenantPrefix in filter exprs (simulates OnModelCreating).

modelBuilder.Entity<Customer>().HasQueryFilter(c => c.CompanyName.StartsWith(TenantPrefix));
modelBuilder.Entity<Order>().HasQueryFilter(o => o.Customer.CompanyName != null);
modelBuilder.Entity<Order>().HasQueryFilter(o => o.Customer != null && o.Customer.CompanyName != null);
modelBuilder.Entity<OrderDetail>().HasQueryFilter(od => EF.Property<short>(od, "Quantity") > _quantity);
modelBuilder.Entity<Employee>().HasQueryFilter(e => e.Address.StartsWith("A"));
modelBuilder.Entity<Product>().HasQueryFilter(p => ClientMethod(p));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ LEFT JOIN (
FROM [Customers] AS [c0]
WHERE ((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c0].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL))))
) AS [t] ON [o].[CustomerID] = [t].[CustomerID]
WHERE [t].[CompanyName] IS NOT NULL
WHERE [t].[CustomerID] IS NOT NULL AND [t].[CompanyName] IS NOT NULL
) AS [t0] ON [c].[CustomerID] = [t0].[CustomerID]
WHERE ((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL))))
ORDER BY [c].[CustomerID], [t0].[OrderID]");
Expand Down Expand Up @@ -156,7 +156,7 @@ LEFT JOIN (
FROM [Customers] AS [c]
WHERE ((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL))))
) AS [t] ON [o].[CustomerID] = [t].[CustomerID]
WHERE [t].[CompanyName] IS NOT NULL");
WHERE [t].[CustomerID] IS NOT NULL AND [t].[CompanyName] IS NOT NULL");
}

public override void Project_reference_that_itself_has_query_filter_with_another_reference()
Expand All @@ -177,7 +177,7 @@ LEFT JOIN (
FROM [Customers] AS [c]
WHERE ((@__ef_filter__TenantPrefix_1 = N'') AND @__ef_filter__TenantPrefix_1 IS NOT NULL) OR ([c].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_1 IS NOT NULL AND (((LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_1)) = @__ef_filter__TenantPrefix_1) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_1)) IS NOT NULL AND @__ef_filter__TenantPrefix_1 IS NOT NULL)) OR (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_1)) IS NULL AND @__ef_filter__TenantPrefix_1 IS NULL))))
) AS [t] ON [o0].[CustomerID] = [t].[CustomerID]
WHERE [t].[CompanyName] IS NOT NULL
WHERE [t].[CustomerID] IS NOT NULL AND [t].[CompanyName] IS NOT NULL
) AS [t0] ON [o].[OrderID] = [t0].[OrderID]
WHERE [o].[Quantity] > @__ef_filter___quantity_0");
}
Expand All @@ -200,7 +200,7 @@ LEFT JOIN (
FROM [Customers] AS [c0]
WHERE ((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c0].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL))))
) AS [t] ON [o].[CustomerID] = [t].[CustomerID]
WHERE [t].[CompanyName] IS NOT NULL
WHERE [t].[CustomerID] IS NOT NULL AND [t].[CompanyName] IS NOT NULL
) AS [t0] ON [c].[CustomerID] = [t0].[CustomerID]
INNER JOIN (
SELECT [o0].[OrderID], [o0].[ProductID], [o0].[Discount], [o0].[Quantity], [o0].[UnitPrice]
Expand Down Expand Up @@ -252,7 +252,7 @@ LEFT JOIN (
FROM [Customers] AS [c]
WHERE ((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL))))
) AS [t] ON [o].[CustomerID] = [t].[CustomerID]
WHERE [t].[CompanyName] IS NOT NULL");
WHERE [t].[CustomerID] IS NOT NULL AND [t].[CompanyName] IS NOT NULL");
}

public override void Compiled_query()
Expand All @@ -275,6 +275,23 @@ FROM [Customers] AS [c]
WHERE (((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL))))) AND (([c].[CustomerID] = @__customerID) AND @__customerID IS NOT NULL)");
}

public override void Entity_Equality()
{
base.Entity_Equality();

AssertSql(
@"@__ef_filter__TenantPrefix_0='B' (Size = 4000)
SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
FROM [Orders] AS [o]
LEFT JOIN (
SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
WHERE ((@__ef_filter__TenantPrefix_0 = N'') AND @__ef_filter__TenantPrefix_0 IS NOT NULL) OR ([c].[CompanyName] IS NOT NULL AND (@__ef_filter__TenantPrefix_0 IS NOT NULL AND (((LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NOT NULL AND @__ef_filter__TenantPrefix_0 IS NOT NULL)) OR (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) IS NULL AND @__ef_filter__TenantPrefix_0 IS NULL))))
) AS [t] ON [o].[CustomerID] = [t].[CustomerID]
WHERE [t].[CustomerID] IS NOT NULL AND [t].[CompanyName] IS NOT NULL");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
}
Expand Down

0 comments on commit aadfe52

Please sign in to comment.