Skip to content

Commit

Permalink
Port AllAnyToContainsRewritingExpressionVisitor to new pipeline
Browse files Browse the repository at this point in the history
Closes #15717
  • Loading branch information
roji committed May 26, 2019
1 parent b28f9ee commit b937ffa
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 165 deletions.

This file was deleted.

2 changes: 1 addition & 1 deletion src/EFCore/Query/Internal/QueryOptimizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ public virtual void Optimize(
queryModel.TransformExpressions(new EntityEqualityRewritingExpressionVisitor(queryCompilationContext).Visit);
queryModel.TransformExpressions(new SubQueryMemberPushDownExpressionVisitor(queryCompilationContext).Visit);
queryModel.TransformExpressions(new ExistsToAnyRewritingExpressionVisitor().Visit);
queryModel.TransformExpressions(new AllAnyToContainsRewritingExpressionVisitor().Visit);
//queryModel.TransformExpressions(new AllAnyToContainsRewritingExpressionVisitor().Visit);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Linq.Expressions;
using System.Reflection;
using Microsoft.EntityFrameworkCore.Query.NavigationExpansion;

namespace Microsoft.EntityFrameworkCore.Query.Pipeline
{
public class AllAnyToContainsRewritingExpressionVisitor : ExpressionVisitor
{
private static readonly MethodInfo _anyMethodInfo =
LinqMethodHelpers.EnumerableAnyPredicateMethodInfo.GetGenericMethodDefinition();

private static readonly MethodInfo _allMethodInfo =
LinqMethodHelpers.EnumerableAllMethodInfo.GetGenericMethodDefinition();

private static readonly MethodInfo _containsMethod =
LinqMethodHelpers.EnumerableContainsMethodInfo.GetGenericMethodDefinition();

protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
if (methodCallExpression.Method.IsGenericMethod
&& methodCallExpression.Method.GetGenericMethodDefinition() is MethodInfo methodInfo
&& (methodInfo.Equals(_anyMethodInfo) || methodInfo.Equals(_allMethodInfo))
&& methodCallExpression.Arguments[0].NodeType is ExpressionType nodeType
&& (nodeType == ExpressionType.Parameter || nodeType == ExpressionType.Constant)
&& methodCallExpression.Arguments[1] is LambdaExpression lambda
&& TryExtractEqualityOperands(lambda.Body, out var left, out var right, out var negated)
&& (left is ParameterExpression || right is ParameterExpression))
{
var nonParameterExpression = left is ParameterExpression ? right : left;

if (methodInfo.Equals(_anyMethodInfo) && !negated)
{
var containsMethod = _containsMethod.MakeGenericMethod(methodCallExpression.Method.GetGenericArguments()[0]);
return Expression.Call(null, containsMethod, methodCallExpression.Arguments[0], nonParameterExpression);
}

if (methodInfo.Equals(_allMethodInfo) && negated)
{
var containsMethod = _containsMethod.MakeGenericMethod(methodCallExpression.Method.GetGenericArguments()[0]);
return Expression.Not(Expression.Call(null, containsMethod, methodCallExpression.Arguments[0], nonParameterExpression));
}
}

return base.VisitMethodCall(methodCallExpression);
}

private bool TryExtractEqualityOperands(Expression expression, out Expression left, out Expression right, out bool negated)
{
(left, right, negated) = (default, default, default);

switch (expression)
{
case BinaryExpression binaryExpression:
{
switch (binaryExpression.NodeType)
{
case ExpressionType.Equal:
negated = false;
break;
case ExpressionType.NotEqual:
negated = true;
break;
default:
return false;
}

(left, right) = (binaryExpression.Left, binaryExpression.Right);
return true;
}

case MethodCallExpression methodCallExpression when methodCallExpression.Method.Name == nameof(object.Equals):
{
negated = false;
if (methodCallExpression.Arguments.Count == 1 && methodCallExpression.Object.Type == methodCallExpression.Arguments[0].Type)
{
(left, right) = (methodCallExpression.Object, methodCallExpression.Arguments[0]);
}
else if (methodCallExpression.Arguments.Count == 2 && methodCallExpression.Arguments[0].Type == methodCallExpression.Arguments[1].Type)
{
(left, right) = (methodCallExpression.Arguments[0], methodCallExpression.Arguments[1]);
}

return true;
}

case UnaryExpression unaryExpression when unaryExpression.NodeType == ExpressionType.Not:
{
var result = TryExtractEqualityOperands(unaryExpression.Operand, out left, out right, out negated);
negated = !negated;
return result;
}
}

return false;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public QueryOptimizer(QueryCompilationContext2 queryCompilationContext)

public Expression Visit(Expression query)
{
query = new AllAnyToContainsRewritingExpressionVisitor().Visit(query);
query = new GroupJoinFlatteningExpressionVisitor().Visit(query);
query = new NullCheckRemovingExpressionVisitor().Visit(query);
query = new NavigationExpander(_queryCompilationContext.Model).ExpandNavigations(query);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1795,7 +1795,7 @@ public virtual Task Project_constant_Sum(bool isAsync)
selector: e => 1);
}

[ConditionalTheory(Skip = "Issue#15717")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_subquery_any_equals_operator(bool isAsync)
{
Expand All @@ -1812,24 +1812,15 @@ public virtual Task Where_subquery_any_equals_operator(bool isAsync)
entryCount: 2);
}

[ConditionalTheory(Skip = "Issue#15717")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_subquery_any_equals(bool isAsync)
{
var ids = new List<string>
{
"ABCDE",
"ALFKI",
"ANATR"
};

return AssertQuery<Customer>(
=> AssertQuery<Customer>(
isAsync,
cs => cs.Where(c => ids.Any(li => li.Equals(c.CustomerID))),
cs => cs.Where(c => new[] { "ABCDE", "ALFKI", "ANATR" }.Any(li => li.Equals(c.CustomerID))),
entryCount: 2);
}

[ConditionalTheory(Skip = "Issue#15717")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_subquery_any_equals_static(bool isAsync)
{
Expand All @@ -1846,24 +1837,29 @@ public virtual Task Where_subquery_any_equals_static(bool isAsync)
entryCount: 2);
}

[ConditionalTheory(Skip = "Issue#15717")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_subquery_where_any(bool isAsync)
public virtual async Task Where_subquery_where_any(bool isAsync)
{
var ids = new List<string>
var ids = new[]
{
"ABCDE",
"ALFKI",
"ANATR"
};

return AssertQuery<Customer>(
await AssertQuery<Customer>(
isAsync,
cs => cs.Where(c => c.City == "México D.F.").Where(c => ids.Any(li => li == c.CustomerID)),
entryCount: 1);

await AssertQuery<Customer>(
isAsync,
cs => cs.Where(c => c.City == "México D.F.").Where(c => ids.Any(li => c.CustomerID == li)),
entryCount: 1);
}

[ConditionalTheory(Skip = "Issue#15717")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_subquery_all_not_equals_operator(bool isAsync)
{
Expand All @@ -1880,24 +1876,15 @@ public virtual Task Where_subquery_all_not_equals_operator(bool isAsync)
entryCount: 89);
}

[ConditionalTheory(Skip = "Issue#15717")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_subquery_all_not_equals(bool isAsync)
{
var ids = new List<string>
{
"ABCDE",
"ALFKI",
"ANATR"
};

return AssertQuery<Customer>(
=> AssertQuery<Customer>(
isAsync,
cs => cs.Where(c => ids.All(li => !li.Equals(c.CustomerID))),
cs => cs.Where(c => new List<string> { "ABCDE", "ALFKI", "ANATR" }.All(li => !li.Equals(c.CustomerID))),
entryCount: 89);
}

[ConditionalTheory(Skip = "Issue#15717")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_subquery_all_not_equals_static(bool isAsync)
{
Expand All @@ -1914,9 +1901,9 @@ public virtual Task Where_subquery_all_not_equals_static(bool isAsync)
entryCount: 89);
}

[ConditionalTheory(Skip = "Issue#15717")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_subquery_where_all(bool isAsync)
public virtual async Task Where_subquery_where_all(bool isAsync)
{
var ids = new List<string>
{
Expand All @@ -1925,10 +1912,15 @@ public virtual Task Where_subquery_where_all(bool isAsync)
"ANATR"
};

return AssertQuery<Customer>(
await AssertQuery<Customer>(
isAsync,
cs => cs.Where(c => c.City == "México D.F.").Where(c => ids.All(li => li != c.CustomerID)),
entryCount: 4);

await AssertQuery<Customer>(
isAsync,
cs => cs.Where(c => c.City == "México D.F.").Where(c => ids.All(li => c.CustomerID != li)),
entryCount: 4);
}

[ConditionalTheory]
Expand Down

0 comments on commit b937ffa

Please sign in to comment.