Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port AllAnyToContainsRewritingExpressionVisitor to new pipeline #15806

Merged
merged 1 commit into from
May 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

This file was deleted.

2 changes: 1 addition & 1 deletion src/EFCore/Query/Internal/QueryOptimizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ public virtual void Optimize(
queryModel.TransformExpressions(new EntityEqualityRewritingExpressionVisitor(queryCompilationContext).Visit);
queryModel.TransformExpressions(new SubQueryMemberPushDownExpressionVisitor(queryCompilationContext).Visit);
queryModel.TransformExpressions(new ExistsToAnyRewritingExpressionVisitor().Visit);
queryModel.TransformExpressions(new AllAnyToContainsRewritingExpressionVisitor().Visit);
//queryModel.TransformExpressions(new AllAnyToContainsRewritingExpressionVisitor().Visit);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Linq.Expressions;
using System.Reflection;
using Microsoft.EntityFrameworkCore.Query.NavigationExpansion;

namespace Microsoft.EntityFrameworkCore.Query.Pipeline
{
public class AllAnyToContainsRewritingExpressionVisitor : ExpressionVisitor
{
private static readonly MethodInfo _anyMethodInfo =
LinqMethodHelpers.EnumerableAnyPredicateMethodInfo.GetGenericMethodDefinition();

private static readonly MethodInfo _allMethodInfo =
LinqMethodHelpers.EnumerableAllMethodInfo.GetGenericMethodDefinition();

private static readonly MethodInfo _containsMethod =
LinqMethodHelpers.EnumerableContainsMethodInfo.GetGenericMethodDefinition();

protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
if (methodCallExpression.Method.IsGenericMethod
&& methodCallExpression.Method.GetGenericMethodDefinition() is MethodInfo methodInfo
&& (methodInfo.Equals(_anyMethodInfo) || methodInfo.Equals(_allMethodInfo))
&& methodCallExpression.Arguments[0].NodeType is ExpressionType nodeType
&& (nodeType == ExpressionType.Parameter || nodeType == ExpressionType.Constant)
&& methodCallExpression.Arguments[1] is LambdaExpression lambda
&& TryExtractEqualityOperands(lambda.Body, out var left, out var right, out var negated)
&& (left is ParameterExpression || right is ParameterExpression))
{
var nonParameterExpression = left is ParameterExpression ? right : left;

if (methodInfo.Equals(_anyMethodInfo) && !negated)
{
var containsMethod = _containsMethod.MakeGenericMethod(methodCallExpression.Method.GetGenericArguments()[0]);
return Expression.Call(null, containsMethod, methodCallExpression.Arguments[0], nonParameterExpression);
}

if (methodInfo.Equals(_allMethodInfo) && negated)
{
var containsMethod = _containsMethod.MakeGenericMethod(methodCallExpression.Method.GetGenericArguments()[0]);
return Expression.Not(Expression.Call(null, containsMethod, methodCallExpression.Arguments[0], nonParameterExpression));
}
}

return base.VisitMethodCall(methodCallExpression);
}

private bool TryExtractEqualityOperands(Expression expression, out Expression left, out Expression right, out bool negated)
{
(left, right, negated) = (default, default, default);

switch (expression)
{
case BinaryExpression binaryExpression:
{
switch (binaryExpression.NodeType)
{
case ExpressionType.Equal:
negated = false;
break;
case ExpressionType.NotEqual:
negated = true;
break;
default:
return false;
}

(left, right) = (binaryExpression.Left, binaryExpression.Right);
return true;
}

case MethodCallExpression methodCallExpression when methodCallExpression.Method.Name == nameof(object.Equals):
{
negated = false;
if (methodCallExpression.Arguments.Count == 1 && methodCallExpression.Object.Type == methodCallExpression.Arguments[0].Type)
{
(left, right) = (methodCallExpression.Object, methodCallExpression.Arguments[0]);
}
else if (methodCallExpression.Arguments.Count == 2 && methodCallExpression.Arguments[0].Type == methodCallExpression.Arguments[1].Type)
{
(left, right) = (methodCallExpression.Arguments[0], methodCallExpression.Arguments[1]);
}

return true;
}

case UnaryExpression unaryExpression when unaryExpression.NodeType == ExpressionType.Not:
{
var result = TryExtractEqualityOperands(unaryExpression.Operand, out left, out right, out negated);
negated = !negated;
return result;
}
}

return false;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public QueryOptimizer(QueryCompilationContext2 queryCompilationContext)

public Expression Visit(Expression query)
{
query = new AllAnyToContainsRewritingExpressionVisitor().Visit(query);
query = new GroupJoinFlatteningExpressionVisitor().Visit(query);
query = new NullCheckRemovingExpressionVisitor().Visit(query);
query = new NavigationExpander(_queryCompilationContext.Model).ExpandNavigations(query);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1795,7 +1795,7 @@ public virtual Task Project_constant_Sum(bool isAsync)
selector: e => 1);
}

[ConditionalTheory(Skip = "Issue#15717")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_subquery_any_equals_operator(bool isAsync)
{
Expand All @@ -1812,24 +1812,15 @@ public virtual Task Where_subquery_any_equals_operator(bool isAsync)
entryCount: 2);
}

[ConditionalTheory(Skip = "Issue#15717")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_subquery_any_equals(bool isAsync)
{
var ids = new List<string>
{
"ABCDE",
"ALFKI",
"ANATR"
};

return AssertQuery<Customer>(
=> AssertQuery<Customer>(
isAsync,
cs => cs.Where(c => ids.Any(li => li.Equals(c.CustomerID))),
cs => cs.Where(c => new[] { "ABCDE", "ALFKI", "ANATR" }.Any(li => li.Equals(c.CustomerID))),
entryCount: 2);
}

[ConditionalTheory(Skip = "Issue#15717")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_subquery_any_equals_static(bool isAsync)
{
Expand All @@ -1846,24 +1837,29 @@ public virtual Task Where_subquery_any_equals_static(bool isAsync)
entryCount: 2);
}

[ConditionalTheory(Skip = "Issue#15717")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_subquery_where_any(bool isAsync)
public virtual async Task Where_subquery_where_any(bool isAsync)
{
var ids = new List<string>
var ids = new[]
{
"ABCDE",
"ALFKI",
"ANATR"
};

return AssertQuery<Customer>(
await AssertQuery<Customer>(
isAsync,
cs => cs.Where(c => c.City == "México D.F.").Where(c => ids.Any(li => li == c.CustomerID)),
entryCount: 1);

await AssertQuery<Customer>(
isAsync,
cs => cs.Where(c => c.City == "México D.F.").Where(c => ids.Any(li => c.CustomerID == li)),
entryCount: 1);
}

[ConditionalTheory(Skip = "Issue#15717")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_subquery_all_not_equals_operator(bool isAsync)
{
Expand All @@ -1880,24 +1876,15 @@ public virtual Task Where_subquery_all_not_equals_operator(bool isAsync)
entryCount: 89);
}

[ConditionalTheory(Skip = "Issue#15717")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_subquery_all_not_equals(bool isAsync)
{
var ids = new List<string>
{
"ABCDE",
"ALFKI",
"ANATR"
};

return AssertQuery<Customer>(
=> AssertQuery<Customer>(
isAsync,
cs => cs.Where(c => ids.All(li => !li.Equals(c.CustomerID))),
cs => cs.Where(c => new List<string> { "ABCDE", "ALFKI", "ANATR" }.All(li => !li.Equals(c.CustomerID))),
entryCount: 89);
}

[ConditionalTheory(Skip = "Issue#15717")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_subquery_all_not_equals_static(bool isAsync)
{
Expand All @@ -1914,9 +1901,9 @@ public virtual Task Where_subquery_all_not_equals_static(bool isAsync)
entryCount: 89);
}

[ConditionalTheory(Skip = "Issue#15717")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_subquery_where_all(bool isAsync)
public virtual async Task Where_subquery_where_all(bool isAsync)
{
var ids = new List<string>
{
Expand All @@ -1925,10 +1912,15 @@ public virtual Task Where_subquery_where_all(bool isAsync)
"ANATR"
};

return AssertQuery<Customer>(
await AssertQuery<Customer>(
isAsync,
cs => cs.Where(c => c.City == "México D.F.").Where(c => ids.All(li => li != c.CustomerID)),
entryCount: 4);

await AssertQuery<Customer>(
isAsync,
cs => cs.Where(c => c.City == "México D.F.").Where(c => ids.All(li => c.CustomerID != li)),
entryCount: 4);
}

[ConditionalTheory]
Expand Down