Skip to content

Commit

Permalink
Fix to #17066 - Query: Translate ElementAt(OrDefault)
Browse files Browse the repository at this point in the history
Enable member pushdown for ElementAt(OrDefault) and then handle the method in translation in a similar way to First(OrDefault) with an optional Skip(x) element.

Fixes #17066
  • Loading branch information
maumar committed Nov 10, 2022
1 parent d4aa1ea commit 1a245a1
Show file tree
Hide file tree
Showing 34 changed files with 1,008 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,24 @@ private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType ent
ShapedQueryExpression source,
Expression index,
bool returnDefault)
=> null;
{
var selectExpression = (SelectExpression)source.QueryExpression;
var translation = TranslateExpression(index);
if (translation == null)
{
return null;
}

if (selectExpression.Orderings.Count == 0)
{
_queryCompilationContext.Logger.RowLimitingOperationWithoutOrderByWarning();
}

selectExpression.ApplyOffset(translation);
selectExpression.ApplyLimit(TranslateExpression(Expression.Constant(1))!);

return source;
}

/// <inheritdoc />
protected override ShapedQueryExpression? TranslateExcept(ShapedQueryExpression source1, ShapedQueryExpression source2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ public class RelationalSqlTranslatingExpressionVisitor : ExpressionVisitor
QueryableMethods.LastWithPredicate,
QueryableMethods.LastWithoutPredicate,
QueryableMethods.LastOrDefaultWithPredicate,
QueryableMethods.LastOrDefaultWithoutPredicate
//QueryableMethodProvider.ElementAtMethodInfo,
//QueryableMethodProvider.ElementAtOrDefaultMethodInfo
QueryableMethods.LastOrDefaultWithoutPredicate,
QueryableMethods.ElementAt,
QueryableMethods.ElementAtOrDefault
};

private static readonly List<MethodInfo> PredicateAggregateMethodInfos = new()
Expand Down Expand Up @@ -346,13 +346,27 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
&& SingleResultMethodInfos.Contains(nonNullMethodCallExpression.Method.GetGenericMethodDefinition()))
{
var source = nonNullMethodCallExpression.Arguments[0];
if (nonNullMethodCallExpression.Arguments.Count == 2)
var genericMethod = nonNullMethodCallExpression.Method.GetGenericMethodDefinition();
if (genericMethod == QueryableMethods.FirstWithPredicate
|| genericMethod == QueryableMethods.FirstOrDefaultWithPredicate
|| genericMethod == QueryableMethods.SingleWithPredicate
|| genericMethod == QueryableMethods.SingleOrDefaultWithPredicate
|| genericMethod == QueryableMethods.LastWithPredicate
|| genericMethod == QueryableMethods.LastOrDefaultWithPredicate)
{
source = Expression.Call(
QueryableMethods.Where.MakeGenericMethod(source.Type.GetSequenceType()),
source,
nonNullMethodCallExpression.Arguments[1]);
}
else if (genericMethod == QueryableMethods.ElementAt
|| genericMethod == QueryableMethods.ElementAtOrDefault)
{
source = Expression.Call(
QueryableMethods.Skip.MakeGenericMethod(source.Type.GetSequenceType()),
source,
nonNullMethodCallExpression.Arguments[1]);
}

var translatedSubquery = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(source);
if (translatedSubquery != null)
Expand Down
81 changes: 81 additions & 0 deletions src/EFCore/Extensions/EntityFrameworkQueryableExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,87 @@ public static Task<long> LongCountAsync<TSource>(

#endregion

#region ElementAt/ElementAtOrDefault

/// <summary>
/// Asynchronously returns the element at a specified index in a sequence.
/// </summary>
/// <remarks>
/// <para>
/// Multiple active operations on the same context instance are not supported. Use <see langword="await" /> to ensure
/// that any asynchronous operations have completed before calling another method on this context.
/// See <see href="https://aka.ms/efcore-docs-threading">Avoiding DbContext threading issues</see> for more information and examples.
/// </para>
/// <para>
/// See <see href="https://aka.ms/efcore-docs-async-linq">Querying data with EF Core</see> for more information and examples.
/// </para>
/// </remarks>
/// <typeparam name="TSource">The type of the elements of <paramref name="source" />.</typeparam>
/// <param name="source">An <see cref="IQueryable{T}" /> to return the element from.</param>
/// <param name="index">The zero-based index of the element to retrieve.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken" /> to observe while waiting for the task to complete.</param>
/// <returns>
/// A task that represents the asynchronous operation.
/// The task result contains the element at a specified index in a <paramref name="source" /> sequence.
/// </returns>
/// <exception cref="ArgumentNullException">
/// <paramref name="source" /> is <see langword="null" />.
/// </exception>
/// <exception cref="ArgumentOutOfRangeException">
/// <para>
/// <paramref name="index" /> is less than zero.
/// </para>
/// </exception>
/// <exception cref="OperationCanceledException">If the <see cref="CancellationToken" /> is canceled.</exception>
public static Task<TSource> ElementAtAsync<TSource>(
this IQueryable<TSource> source,
int index,
CancellationToken cancellationToken = default)
{
Check.NotNull(index, nameof(index));

return ExecuteAsync<TSource, Task<TSource>>(
QueryableMethods.ElementAt, source, Expression.Constant(index), cancellationToken);
}

/// <summary>
/// Asynchronously returns the element at a specified index in a sequence, or a default value if the index is out of range.
/// </summary>
/// <remarks>
/// <para>
/// Multiple active operations on the same context instance are not supported. Use <see langword="await" /> to ensure
/// that any asynchronous operations have completed before calling another method on this context.
/// See <see href="https://aka.ms/efcore-docs-threading">Avoiding DbContext threading issues</see> for more information and examples.
/// </para>
/// <para>
/// See <see href="https://aka.ms/efcore-docs-async-linq">Querying data with EF Core</see> for more information and examples.
/// </para>
/// </remarks>
/// <typeparam name="TSource">The type of the elements of <paramref name="source" />.</typeparam>
/// <param name="source">An <see cref="IQueryable{T}" /> to return the element from.</param>
/// <param name="index">The zero-based index of the element to retrieve.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken" /> to observe while waiting for the task to complete.</param>
/// <returns>
/// A task that represents the asynchronous operation.
/// The task result contains the element at a specified index in a <paramref name="source" /> sequence.
/// </returns>
/// <exception cref="ArgumentNullException">
/// <paramref name="source" /> is <see langword="null" />.
/// </exception>
/// <exception cref="OperationCanceledException">If the <see cref="CancellationToken" /> is canceled.</exception>
public static Task<TSource> ElementAtOrDefaultAsync<TSource>(
this IQueryable<TSource> source,
int index,
CancellationToken cancellationToken = default)
{
Check.NotNull(index, nameof(index));

return ExecuteAsync<TSource, Task<TSource>>(
QueryableMethods.ElementAtOrDefault, source, Expression.Constant(index), cancellationToken);
}

#endregion

#region First/FirstOrDefault

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1000,10 +1000,13 @@ private sealed class ReducingExpressionVisitor : ExpressionVisitor

if (navigationExpansionExpression.CardinalityReducingGenericMethodInfo != null)
{
var arguments = new List<Expression> { result };
arguments.AddRange(navigationExpansionExpression.CardinalityReducingMethodArguments.Select(x => Visit(x)));

result = Expression.Call(
navigationExpansionExpression.CardinalityReducingGenericMethodInfo.MakeGenericMethod(
result.Type.GetSequenceType()),
result);
arguments.ToArray());
}

return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ private set

public Expression PendingSelector { get; private set; }
public MethodInfo? CardinalityReducingGenericMethodInfo { get; private set; }
public List<Expression> CardinalityReducingMethodArguments { get; private set; } = new();

public Type SourceElementType
=> CurrentParameter.Type;
Expand Down Expand Up @@ -274,8 +275,11 @@ public void AppendPendingOrdering(MethodInfo orderingMethod, Expression keySelec
public void ClearPendingOrderings()
=> _pendingOrderings.Clear();

public void ConvertToSingleResult(MethodInfo genericMethod)
=> CardinalityReducingGenericMethodInfo = genericMethod;
public void ConvertToSingleResult(MethodInfo genericMethod, params Expression[] arguments)
{
CardinalityReducingGenericMethodInfo = genericMethod;
CardinalityReducingMethodArguments.AddRange(arguments);
}

public override ExpressionType NodeType
=> ExpressionType.Extension;
Expand Down
26 changes: 26 additions & 0 deletions src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,16 @@ when QueryableMethods.IsSumWithSelector(method):
methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(),
methodCallExpression.Type);

case nameof(Queryable.ElementAt)
when genericMethod == QueryableMethods.ElementAt:
case nameof(Queryable.ElementAtOrDefault)
when genericMethod == QueryableMethods.ElementAtOrDefault:
return ProcessElementAt(
source,
genericMethod,
methodCallExpression.Arguments[1],
methodCallExpression.Type);

case nameof(Queryable.Join)
when genericMethod == QueryableMethods.Join:
{
Expand Down Expand Up @@ -982,6 +992,22 @@ private NavigationExpansionExpression ProcessFirstSingleLastOrDefault(
return source;
}

private NavigationExpansionExpression ProcessElementAt(
NavigationExpansionExpression source,
MethodInfo genericMethod,
Expression index,
Type returnType)
{
if (source.PendingSelector.Type != returnType)
{
source.ApplySelector(Expression.Convert(source.PendingSelector, returnType));
}

source.ConvertToSingleResult(genericMethod, index);

return source;
}

// This returns Expression since it can also return a deferred GroupBy operation
private Expression ProcessGroupBy(
NavigationExpansionExpression source,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ public class SubqueryMemberPushdownExpressionVisitor : ExpressionVisitor
QueryableMethods.LastWithPredicate,
QueryableMethods.LastWithoutPredicate,
QueryableMethods.LastOrDefaultWithPredicate,
QueryableMethods.LastOrDefaultWithoutPredicate
//QueryableMethodProvider.ElementAtMethodInfo,
//QueryableMethodProvider.ElementAtOrDefaultMethodInfo
QueryableMethods.LastOrDefaultWithoutPredicate,
QueryableMethods.ElementAt,
QueryableMethods.ElementAtOrDefault
};

private static readonly IDictionary<MethodInfo, MethodInfo> PredicateLessMethodInfo = new Dictionary<MethodInfo, MethodInfo>
Expand Down Expand Up @@ -163,7 +163,13 @@ private Expression PushdownMember(
var source = methodCallExpression.Arguments[0];
var queryableType = source.Type.GetSequenceType();
var genericMethod = methodCallExpression.Method.GetGenericMethodDefinition();
if (methodCallExpression.Arguments.Count == 2)

if (genericMethod == QueryableMethods.FirstWithPredicate
|| genericMethod == QueryableMethods.FirstOrDefaultWithPredicate
|| genericMethod == QueryableMethods.SingleWithPredicate
|| genericMethod == QueryableMethods.SingleOrDefaultWithPredicate
|| genericMethod == QueryableMethods.LastWithPredicate
|| genericMethod == QueryableMethods.LastOrDefaultWithPredicate)
{
// Move predicate to Where so that we can change shape before operator
source = Expression.Call(
Expand Down Expand Up @@ -203,7 +209,16 @@ private Expression PushdownMember(
Expression.Quote(Expression.Lambda(memberAccessExpression, parameter)));
}

source = Expression.Call(genericMethod.MakeGenericMethod(source.Type.GetSequenceType()), source);
if (genericMethod == QueryableMethods.ElementAt
|| genericMethod == QueryableMethods.ElementAtOrDefault)
{
var index = Visit(methodCallExpression.Arguments[1]);
source = Expression.Call(genericMethod.MakeGenericMethod(source.Type.GetSequenceType()), source, index);
}
else
{
source = Expression.Call(genericMethod.MakeGenericMethod(source.Type.GetSequenceType()), source);
}

return source.Type != returnType
? Expression.Convert(source, returnType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4584,6 +4584,27 @@ public override async Task Select_subquery_recursive_trivial_returning_queryable
AssertSql();
}

public override async Task Collection_navigation_equal_to_null_for_subquery_using_ElementAtOrDefault_constant_zero(bool async)
{
await AssertTranslationFailed(() => base.Collection_navigation_equal_to_null_for_subquery_using_ElementAtOrDefault_constant_zero(async));

AssertSql();
}

public override async Task Collection_navigation_equal_to_null_for_subquery_using_ElementAtOrDefault_constant_one(bool async)
{
await AssertTranslationFailed(() => base.Collection_navigation_equal_to_null_for_subquery_using_ElementAtOrDefault_constant_one(async));

AssertSql();
}

public override async Task Collection_navigation_equal_to_null_for_subquery_using_ElementAtOrDefault_parameter(bool async)
{
await AssertTranslationFailed(() => base.Collection_navigation_equal_to_null_for_subquery_using_ElementAtOrDefault_parameter(async));

AssertSql();
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2413,6 +2413,22 @@ public override async Task First_over_custom_projection_compared_to_not_null(boo
AssertSql();
}

public override async Task ElementAt_over_custom_projection_compared_to_not_null(bool async)
{
// Cosmos client evaluation. Issue #17246.
await AssertTranslationFailed(() => base.ElementAt_over_custom_projection_compared_to_not_null(async));

AssertSql();
}

public override async Task ElementAtOrDefault_over_custom_projection_compared_to_null(bool async)
{
// Cosmos client evaluation. Issue #17246.
await AssertTranslationFailed(() => base.ElementAtOrDefault_over_custom_projection_compared_to_null(async));

AssertSql();
}

public override async Task Single_over_custom_projection_compared_to_null(bool async)
{
// Cosmos client evaluation. Issue #17246.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,19 @@ public override Task Null_semantics_is_correctly_applied_for_function_comparison
// Null protection. Issue #13721.
=> Assert.ThrowsAsync<InvalidOperationException>(
() => base.Null_semantics_is_correctly_applied_for_function_comparisons_that_take_arguments_from_optional_navigation(async));

public override Task ElementAt_basic_with_OrderBy(bool async)
=> Task.CompletedTask;

public override Task ElementAtOrDefault_basic_with_OrderBy(bool async)
=> Task.CompletedTask;

public override Task ElementAtOrDefault_basic_with_OrderBy_parameter(bool async)
=> Task.CompletedTask;

public override Task Where_subquery_with_ElementAtOrDefault_equality_to_null_with_composite_key(bool async)
=> Task.CompletedTask;

public override Task Where_subquery_with_ElementAt_using_column_as_index(bool async)
=> Task.CompletedTask;
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,13 @@ public override async Task Entity_equality_through_subquery_composite_key(bool a
CoreStrings.EntityEqualityOnCompositeKeyEntitySubqueryNotSupported("==", nameof(OrderDetail)),
(await Assert.ThrowsAsync<InvalidOperationException>(
() => base.Entity_equality_through_subquery_composite_key(async))).Message);

public override Task Collection_navigation_equal_to_null_for_subquery_using_ElementAtOrDefault_constant_zero(bool async)
=> Task.CompletedTask;

public override Task Collection_navigation_equal_to_null_for_subquery_using_ElementAtOrDefault_constant_one(bool async)
=> Task.CompletedTask;

public override Task Collection_navigation_equal_to_null_for_subquery_using_ElementAtOrDefault_parameter(bool async)
=> Task.CompletedTask;
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,10 @@ public override async Task<string> Where_simple_closure(bool async)
public override Task Like_with_non_string_column_using_double_cast(bool async)
// Casting int to object to string is invalid for InMemory
=> Assert.ThrowsAsync<InvalidCastException>(() => base.Like_with_non_string_column_using_double_cast(async));

public override Task ElementAt_over_custom_projection_compared_to_not_null(bool async)
=> Task.CompletedTask;

public override Task ElementAtOrDefault_over_custom_projection_compared_to_null(bool async)
=> Task.CompletedTask;
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ protected ComplexNavigationsCollectionsSplitQueryRelationalTestBase(TFixture fix
}

protected override Expression RewriteServerQueryExpression(Expression serverQueryExpression)
=> new SplitQueryRewritingExpressionVisitor().Visit(serverQueryExpression);
{
serverQueryExpression = base.RewriteServerQueryExpression(serverQueryExpression);

return new SplitQueryRewritingExpressionVisitor().Visit(serverQueryExpression);
}

private class SplitQueryRewritingExpressionVisitor : ExpressionVisitor
{
Expand Down
Loading

0 comments on commit 1a245a1

Please sign in to comment.