Skip to content

Commit

Permalink
Query: Refactor SelectExpression for referential integrity
Browse files Browse the repository at this point in the history
  • Loading branch information
smitpatel committed Mar 28, 2017
1 parent 69d25a0 commit bb66f16
Show file tree
Hide file tree
Showing 36 changed files with 1,680 additions and 1,439 deletions.
43 changes: 24 additions & 19 deletions src/EFCore.Relational/Extensions/RelationalExpressionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,37 +4,42 @@
using System.Linq.Expressions;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Query.Expressions;
using Microsoft.EntityFrameworkCore.Utilities;

// ReSharper disable once CheckNamespace
namespace Microsoft.EntityFrameworkCore.Internal
{
public static class RelationalExpressionExtensions
{
public static ColumnExpression TryGetColumnExpression([NotNull] this Expression expression)
=> expression as ColumnExpression
?? (expression as AliasExpression)?.TryGetColumnExpression()
?? (expression as NullableExpression)?.Operand.TryGetColumnExpression();

public static bool IsAliasWithColumnExpression([NotNull] this Expression expression)
=> (expression as AliasExpression)?.Expression is ColumnExpression;

public static bool IsAliasWithSelectExpression([NotNull] this Expression expression)
=> (expression as AliasExpression)?.Expression is SelectExpression;

public static bool HasColumnExpression([CanBeNull] this AliasExpression aliasExpression)
=> aliasExpression?.Expression is ColumnExpression;

public static ColumnExpression TryGetColumnExpression([NotNull] this AliasExpression aliasExpression)
=> aliasExpression.Expression as ColumnExpression;

public static bool IsSimpleExpression([NotNull] this Expression expression)
{
Check.NotNull(expression, nameof(expression));

var unwrappedExpression = expression.RemoveConvert();

return unwrappedExpression is ConstantExpression
|| unwrappedExpression is ColumnExpression
|| unwrappedExpression is ParameterExpression
|| unwrappedExpression.IsAliasWithColumnExpression();
|| unwrappedExpression is ColumnReferenceExpression
|| unwrappedExpression is AliasExpression;
}

public static ColumnReferenceExpression LiftExpressionFromSubquery([NotNull] this Expression expression, [NotNull] TableExpressionBase table)
{
Check.NotNull(expression, nameof(expression));
Check.NotNull(table, nameof(table));

switch (expression)
{
case ColumnExpression columnExpression:
return new ColumnReferenceExpression(columnExpression, table);
case AliasExpression aliasExpression:
return new ColumnReferenceExpression(aliasExpression, table);
case ColumnReferenceExpression columnReferenceExpression:
return new ColumnReferenceExpression(columnReferenceExpression, table);
}

return null;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Query.Expressions;
using Microsoft.EntityFrameworkCore.Utilities;
using Remotion.Linq.Parsing;
Expand All @@ -21,68 +20,67 @@ public class EqualityPredicateInExpressionOptimizer : RelinqExpressionVisitor
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
/// </summary>
protected override Expression VisitBinary(BinaryExpression node)
protected override Expression VisitBinary(BinaryExpression binaryExpression)
{
Check.NotNull(node, nameof(node));
Check.NotNull(binaryExpression, nameof(binaryExpression));

switch (node.NodeType)
switch (binaryExpression.NodeType)
{
case ExpressionType.OrElse:
{
return Optimize(
node,
binaryExpression,
equalityType: ExpressionType.Equal,
inExpressionFactory: (c, vs) => new InExpression(c, vs));
}

case ExpressionType.AndAlso:
{
return Optimize(
node,
binaryExpression,
equalityType: ExpressionType.NotEqual,
inExpressionFactory: (c, vs) => Expression.Not(new InExpression(c, vs)));
}
}

return base.VisitBinary(node);
return base.VisitBinary(binaryExpression);
}

private Expression Optimize(
BinaryExpression binaryExpression,
ExpressionType equalityType,
Func<ColumnExpression, List<Expression>, Expression> inExpressionFactory)
Func<Expression, List<Expression>, Expression> inExpressionFactory)
{
var leftExpression = Visit(binaryExpression.Left);
var rightExpression = Visit(binaryExpression.Right);

Expression leftNonColumnExpression, rightNonColumnExpression;
IReadOnlyList<Expression> leftInValues = null;
IReadOnlyList<Expression> rightInValues = null;

var leftColumnExpression
= MatchEqualityExpression(
leftExpression,
equalityType,
out leftNonColumnExpression);
out Expression leftNonColumnExpression);

var rightColumnExpression
= MatchEqualityExpression(
rightExpression,
equalityType,
out rightNonColumnExpression);
out Expression rightNonColumnExpression);

if (leftColumnExpression == null)
{
leftColumnExpression = ((equalityType == ExpressionType.Equal
leftColumnExpression = equalityType == ExpressionType.Equal
? MatchInExpression(leftExpression, ref leftInValues)
: MatchNotInExpression(leftExpression, ref leftInValues))).TryGetColumnExpression();
: MatchNotInExpression(leftExpression, ref leftInValues);
}

if (rightColumnExpression == null)
{
rightColumnExpression = ((equalityType == ExpressionType.Equal
rightColumnExpression = equalityType == ExpressionType.Equal
? MatchInExpression(rightExpression, ref rightInValues)
: MatchNotInExpression(rightExpression, ref rightInValues))).TryGetColumnExpression();
: MatchNotInExpression(rightExpression, ref rightInValues);
}

if (leftColumnExpression != null
Expand Down Expand Up @@ -118,7 +116,7 @@ var rightColumnExpression
return binaryExpression.Update(leftExpression, binaryExpression.Conversion, rightExpression);
}

private static ColumnExpression MatchEqualityExpression(
private static Expression MatchEqualityExpression(
Expression expression,
ExpressionType equalityType,
out Expression nonColumnExpression)
Expand All @@ -129,16 +127,16 @@ private static ColumnExpression MatchEqualityExpression(

if (binaryExpression?.NodeType == equalityType)
{
nonColumnExpression
= binaryExpression.Right as ConstantExpression
?? binaryExpression.Right as ParameterExpression
?? (Expression)(binaryExpression.Left as ConstantExpression)
?? binaryExpression.Left as ParameterExpression;
var left = binaryExpression.Left;
var right = binaryExpression.Right;

if (nonColumnExpression != null)
var isLeftConstantOrParameter = left is ConstantExpression || left is ParameterExpression;

if (isLeftConstantOrParameter || right is ConstantExpression || right is ParameterExpression)
{
return binaryExpression.Right.TryGetColumnExpression()
?? binaryExpression.Left.TryGetColumnExpression();
nonColumnExpression = isLeftConstantOrParameter ? left : right;

return isLeftConstantOrParameter ? right : left;
}
}

Expand All @@ -149,9 +147,7 @@ private static Expression MatchInExpression(
Expression expression,
ref IReadOnlyList<Expression> values)
{
var inExpression = expression as InExpression;

if (inExpression != null)
if (expression is InExpression inExpression)
{
values = inExpression.Values;

Expand All @@ -167,8 +163,8 @@ private static Expression MatchNotInExpression(
{
var unaryExpression = expression as UnaryExpression;

return (unaryExpression != null)
&& (unaryExpression.NodeType == ExpressionType.Not)
return unaryExpression != null
&& unaryExpression.NodeType == ExpressionType.Not
? MatchInExpression(unaryExpression.Operand, ref values)
: null;
}
Expand Down
Loading

0 comments on commit bb66f16

Please sign in to comment.