Skip to content

Commit

Permalink
Normalize Any to Contains instead of vice versa (#30956)
Browse files Browse the repository at this point in the history
E.g. for easier pattern-matching of Contains
  • Loading branch information
roji authored Jun 1, 2023
1 parent e5b7c93 commit 30d689e
Show file tree
Hide file tree
Showing 29 changed files with 446 additions and 418 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -236,28 +236,12 @@ private static ShapedQueryExpression CreateShapedQueryExpressionStatic(IEntityTy
/// </summary>
protected override ShapedQueryExpression? TranslateContains(ShapedQueryExpression source, Expression item)
{
var inMemoryQueryExpression = (InMemoryQueryExpression)source.QueryExpression;
var newItem = TranslateExpression(item, preserveType: true);
if (newItem == null)
{
return null;
}

item = newItem;
var anyLambdaParameter = Expression.Parameter(item.Type, "p");
var anyLambda = Expression.Lambda(
ExpressionExtensions.CreateEqualsExpression(anyLambdaParameter, item),
anyLambdaParameter);

inMemoryQueryExpression.UpdateServerQueryExpression(
Expression.Call(
EnumerableMethods.Contains.MakeGenericMethod(item.Type),
Expression.Call(
EnumerableMethods.Select.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type, item.Type),
inMemoryQueryExpression.ServerQueryExpression,
Expression.Lambda(
inMemoryQueryExpression.GetProjection(
new ProjectionBindingExpression(inMemoryQueryExpression, new ProjectionMember(), item.Type)),
inMemoryQueryExpression.CurrentParameter)),
item));

return source.UpdateShaperExpression(Expression.Convert(inMemoryQueryExpression.GetSingleScalarProjection(), typeof(bool)));
return TranslateAny(source, anyLambda);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -411,33 +411,28 @@ private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType ent
return null;
}

var selectExpression = (SelectExpression)source.QueryExpression;
var subquery = (SelectExpression)source.QueryExpression;

// Negate the predicate, unless it's already negated, in which case remove that.
selectExpression.ApplyPredicate(
subquery.ApplyPredicate(
translation is SqlUnaryExpression { OperatorType: ExpressionType.Not, Operand: var nestedOperand }
? nestedOperand
: _sqlExpressionFactory.Not(translation));

if (TrySimplifyValuesToInExpression(source, isNegated: true, out var simplifiedQuery))
{
return simplifiedQuery;
}

selectExpression.ReplaceProjection(new List<Expression>());
selectExpression.ApplyProjection();
if (selectExpression.Limit == null
&& selectExpression.Offset == null)
subquery.ReplaceProjection(new List<Expression>());
subquery.ApplyProjection();
if (subquery.Limit == null
&& subquery.Offset == null)
{
selectExpression.ClearOrdering();
subquery.ClearOrdering();
}

translation = _sqlExpressionFactory.Exists(selectExpression, true);
selectExpression = _sqlExpressionFactory.Select(translation);
translation = _sqlExpressionFactory.Exists(subquery, true);
subquery = _sqlExpressionFactory.Select(translation);

return source.Update(
selectExpression,
Expression.Convert(new ProjectionBindingExpression(selectExpression, new ProjectionMember(), typeof(bool?)), typeof(bool)));
subquery,
Expression.Convert(new ProjectionBindingExpression(subquery, new ProjectionMember(), typeof(bool?)), typeof(bool)));
}

/// <inheritdoc />
Expand All @@ -452,24 +447,19 @@ private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType ent
}

source = translatedSource;

if (TrySimplifyValuesToInExpression(source, isNegated: false, out var simplifiedQuery))
{
return simplifiedQuery;
}
}

var selectExpression = (SelectExpression)source.QueryExpression;
selectExpression.ReplaceProjection(new List<Expression>());
selectExpression.ApplyProjection();
if (selectExpression.Limit == null
&& selectExpression.Offset == null)
var subquery = (SelectExpression)source.QueryExpression;
subquery.ReplaceProjection(new List<Expression>());
subquery.ApplyProjection();
if (subquery.Limit == null
&& subquery.Offset == null)
{
selectExpression.ClearOrdering();
subquery.ClearOrdering();
}

var translation = _sqlExpressionFactory.Exists(selectExpression, false);
selectExpression = _sqlExpressionFactory.Select(translation);
var translation = _sqlExpressionFactory.Exists(subquery, false);
var selectExpression = _sqlExpressionFactory.Select(translation);

return source.Update(
selectExpression,
Expand Down Expand Up @@ -501,47 +491,65 @@ private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType ent
/// <inheritdoc />
protected override ShapedQueryExpression? TranslateContains(ShapedQueryExpression source, Expression item)
{
var selectExpression = (SelectExpression)source.QueryExpression;
var translation = TranslateExpression(item);
if (translation == null)
{
return null;
}

if (selectExpression.Limit == null
&& selectExpression.Offset == null)
{
selectExpression.ClearOrdering();
}

var shaperExpression = source.ShaperExpression;
// No need to check ConvertChecked since this is convert node which we may have added during projection
if (shaperExpression is UnaryExpression { NodeType: ExpressionType.Convert } unaryExpression
&& unaryExpression.Operand.Type.IsNullableType()
&& unaryExpression.Operand.Type.UnwrapNullableType() == unaryExpression.Type)
{
shaperExpression = unaryExpression.Operand;
}

if (shaperExpression is ProjectionBindingExpression projectionBindingExpression)
// Pattern-match Contains over ValuesExpression, translating to simplified 'item IN (1, 2, 3)' with constant elements
if (source.QueryExpression is SelectExpression
{
Tables:
[
ValuesExpression
{
RowValues: [{ Values.Count: 2 }, ..],
ColumnNames: [ValuesOrderingColumnName, ValuesValueColumnName]
} valuesExpression
],
Predicate: null,
GroupBy: [],
Having: null,
IsDistinct: false,
Limit: null,
Offset: null,
// Note that in the context of Contains we don't care about orderings
}
// Make sure that the source projects the column from the ValuesExpression directly, i.e. no projection out with some expression
&& TryGetProjection(source, out var projection)
&& projection is ColumnExpression projectedColumn
&& projectedColumn.Table == valuesExpression)
{
var projection = selectExpression.GetProjection(projectionBindingExpression);
if (projection is SqlExpression sqlExpression)
if (TranslateExpression(item) is not SqlExpression translatedItem)
{
selectExpression.ReplaceProjection(new List<Expression> { sqlExpression });
selectExpression.ApplyProjection();
return null;
}

translation = _sqlExpressionFactory.In(translation, selectExpression, false);
selectExpression = _sqlExpressionFactory.Select(translation);
var values = new object?[valuesExpression.RowValues.Count];
for (var i = 0; i < values.Length; i++)
{
// Skip the first value (_ord), which is irrelevant for Contains
if (valuesExpression.RowValues[i].Values[1] is SqlConstantExpression { Value: var constantValue })
{
values[i] = constantValue;
}
else
{
// We only support constants for now
values = null;
break;
}
}

return source.Update(
selectExpression,
Expression.Convert(
new ProjectionBindingExpression(selectExpression, new ProjectionMember(), typeof(bool?)), typeof(bool)));
if (values is not null)
{
var inExpression = _sqlExpressionFactory.In(translatedItem, _sqlExpressionFactory.Constant(values), negated: false);
return source.Update(_sqlExpressionFactory.Select(inExpression), source.ShaperExpression);
}
}

return null;
// TODO: This generates an EXISTS subquery. Translate to IN instead: #30955
var anyLambdaParameter = Expression.Parameter(item.Type, "p");
var anyLambda = Expression.Lambda(
Infrastructure.ExpressionExtensions.CreateEqualsExpression(anyLambdaParameter, item),
anyLambdaParameter);

return TranslateAny(source, anyLambda);
}

/// <inheritdoc />
Expand Down Expand Up @@ -1812,89 +1820,6 @@ protected virtual Expression ApplyInferredTypeMappings(
protected virtual bool IsOrdered(SelectExpression selectExpression)
=> selectExpression.Orderings.Count > 0;

/// <summary>
/// Attempts to pattern-match for Contains over <see cref="ValuesExpression" />, which corresponds to
/// <c>Where(b => new[] { 1, 2, 3 }.Contains(b.Id))</c>. Simplifies this to the tighter <c>[b].[Id] IN (1, 2, 3)</c> instead of the
/// full subquery with VALUES.
/// </summary>
private bool TrySimplifyValuesToInExpression(
ShapedQueryExpression source,
bool isNegated,
[NotNullWhen(true)] out ShapedQueryExpression? simplifiedQuery)
{
if (source.QueryExpression is SelectExpression
{
Tables: [ValuesExpression
{
RowValues: [{ Values.Count: 2 }, ..],
ColumnNames: [ ValuesOrderingColumnName, ValuesValueColumnName ]
} valuesExpression],
GroupBy: [],
Having: null,
IsDistinct: false,
Limit: null,
Offset: null,
// Note that we don't care about orderings, they get elided anyway by Any/All
Predicate: SqlBinaryExpression { OperatorType: ExpressionType.Equal, Left: var left, Right: var right },
} selectExpression)
{
// The table is a ValuesExpression, and the predicate is an equality - this is a possible simplifiable Contains.
// Get the projection column pointing to the ValuesExpression, and check that it's compared to on one side of the predicate
// equality.
var shaperExpression = source.ShaperExpression;
if (shaperExpression is UnaryExpression { NodeType: ExpressionType.Convert } unaryExpression
&& unaryExpression.Operand.Type.IsNullableType()
&& unaryExpression.Operand.Type.UnwrapNullableType() == unaryExpression.Type)
{
shaperExpression = unaryExpression.Operand;
}

if (shaperExpression is ProjectionBindingExpression projectionBindingExpression
&& selectExpression.GetProjection(projectionBindingExpression) is ColumnExpression projectionColumn)
{
SqlExpression item;

if (left is ColumnExpression leftColumn
&& (leftColumn.Table, leftColumn.Name) == (projectionColumn.Table, projectionColumn.Name))
{
item = right;
}
else if (right is ColumnExpression rightColumn
&& (rightColumn.Table, rightColumn.Name) == (projectionColumn.Table, projectionColumn.Name))
{
item = left;
}
else
{
simplifiedQuery = null;
return false;
}

var values = new object?[valuesExpression.RowValues.Count];
for (var i = 0; i < values.Length; i++)
{
// Skip the first value (_ord), which is irrelevant for Contains
if (valuesExpression.RowValues[i].Values[1] is SqlConstantExpression { Value: var constantValue })
{
values[i] = constantValue;
}
else
{
simplifiedQuery = null;
return false;
}
}

var inExpression = _sqlExpressionFactory.In(item, _sqlExpressionFactory.Constant(values), isNegated);
simplifiedQuery = source.Update(_sqlExpressionFactory.Select(inExpression), source.ShaperExpression);
return true;
}
}

simplifiedQuery = null;
return false;
}

private Expression RemapLambdaBody(ShapedQueryExpression shapedQueryExpression, LambdaExpression lambdaExpression)
{
var lambdaBody = ReplacingExpressionVisitor.Replace(
Expand Down Expand Up @@ -2569,6 +2494,29 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape
return source.UpdateShaperExpression(shaper);
}

private bool TryGetProjection(ShapedQueryExpression shapedQueryExpression, [NotNullWhen(true)] out SqlExpression? projection)
{
var shaperExpression = shapedQueryExpression.ShaperExpression;
// No need to check ConvertChecked since this is convert node which we may have added during projection
if (shaperExpression is UnaryExpression { NodeType: ExpressionType.Convert } unaryExpression
&& unaryExpression.Operand.Type.IsNullableType()
&& unaryExpression.Operand.Type.UnwrapNullableType() == unaryExpression.Type)
{
shaperExpression = unaryExpression.Operand;
}

if (shapedQueryExpression.QueryExpression is SelectExpression selectExpression
&& shaperExpression is ProjectionBindingExpression projectionBindingExpression
&& selectExpression.GetProjection(projectionBindingExpression) is SqlExpression sqlExpression)
{
projection = sqlExpression;
return true;
}

projection = null;
return false;
}

/// <summary>
/// A visitor which scans an expression tree and attempts to find columns for which we were missing type mappings (projected out
/// of queryable constant/parameter), and those type mappings have been inferred.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1133,9 +1133,7 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)
var operand = Visit(unaryExpression.Operand);

if (operand is EntityReferenceExpression entityReferenceExpression
&& (unaryExpression.NodeType == ExpressionType.Convert
|| unaryExpression.NodeType == ExpressionType.ConvertChecked
|| unaryExpression.NodeType == ExpressionType.TypeAs))
&& unaryExpression.NodeType is ExpressionType.Convert or ExpressionType.ConvertChecked or ExpressionType.TypeAs)
{
return entityReferenceExpression.Convert(unaryExpression.Type);
}
Expand All @@ -1148,7 +1146,13 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)
switch (unaryExpression.NodeType)
{
case ExpressionType.Not:
return _sqlExpressionFactory.Not(sqlOperand!);
return sqlOperand switch
{
ExistsExpression e => e.Negate(),
InExpression e => e.Negate(),

_ => _sqlExpressionFactory.Not(sqlOperand!)
};

case ExpressionType.Negate:
case ExpressionType.NegateChecked:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ public ExistsExpression(
protected override Expression VisitChildren(ExpressionVisitor visitor)
=> Update((SelectExpression)visitor.Visit(Subquery));

/// <summary>
/// Negates this expression by changing presence/absence state indicated by <see cref="IsNegated" />.
/// </summary>
/// <returns>An expression which is negated form of this expression.</returns>
public virtual ExistsExpression Negate()
=> new(Subquery, !IsNegated, TypeMapping);

/// <summary>
/// Creates a new expression that is like this one, but using the supplied children. If all of the children are the same, it will
/// return this expression.
Expand Down
Loading

0 comments on commit 30d689e

Please sign in to comment.