Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

InMemory: Support for non-scalar single result in projection #17405

Merged
merged 1 commit into from
Aug 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ private static IEnumerable<MethodInfo> GetMethods(string name, int parameterCoun
public static MethodInfo SinglePredicate = GetMethod(nameof(Enumerable.Single), 1);
public static MethodInfo SingleOrDefaultPredicate = GetMethod(nameof(Enumerable.SingleOrDefault), 1);

public static MethodInfo First = GetMethod(nameof(Enumerable.First), 0);
public static MethodInfo FirstOrDefault = GetMethod(nameof(Enumerable.FirstOrDefault), 0);
public static MethodInfo Last = GetMethod(nameof(Enumerable.Last), 0);
public static MethodInfo LastOrDefault = GetMethod(nameof(Enumerable.LastOrDefault), 0);
public static MethodInfo Single = GetMethod(nameof(Enumerable.Single), 0);
public static MethodInfo SingleOrDefault = GetMethod(nameof(Enumerable.SingleOrDefault), 0);

public static MethodInfo GetAggregateMethod(string methodName, Type elementType, int parameterCount = 0)
{
Check.NotEmpty(methodName, nameof(methodName));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,13 @@ public override Expression Visit(Expression expression)
return AddCollectionProjection(subquery, null, subquery.ShaperExpression.Type);
}

throw new InvalidOperationException(CoreStrings.QueryFailed(methodCallExpression.Print(), GetType().Name));
return new SingleResultShaperExpression(
new ProjectionBindingExpression(
_queryExpression,
_queryExpression.AddSubqueryProjection(subquery, out var innerShaper),
typeof(ValueBuffer)),
innerShaper,
subquery.ShaperExpression.Type);
}

break;
Expand Down
12 changes: 11 additions & 1 deletion src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,23 @@ public virtual int AddSubqueryProjection(ShapedQueryExpression shapedQueryExpres
{
var subquery = (InMemoryQueryExpression)shapedQueryExpression.QueryExpression;
subquery.ApplyProjection();
var serverQueryExpression = subquery.ServerQueryExpression;

if (serverQueryExpression is MethodCallExpression selectMethodCall
&& selectMethodCall.Arguments[0].Type == typeof(ResultEnumerable))
{
var terminatingMethodCall = (MethodCallExpression)((LambdaExpression)((NewExpression)selectMethodCall.Arguments[0]).Arguments[0]).Body;
selectMethodCall = selectMethodCall.Update(
null, new[] { terminatingMethodCall.Arguments[0], selectMethodCall.Arguments[1] });
serverQueryExpression = terminatingMethodCall.Update(null, new[] { selectMethodCall });
}

innerShaper = new ShaperRemappingExpressionVisitor(subquery._projectionMapping)
.Visit(shapedQueryExpression.ShaperExpression);

innerShaper = Lambda(innerShaper, subquery.ValueBufferParameter);

return AddToProjection(subquery.ServerQueryExpression);
return AddToProjection(serverQueryExpression);
}

private class ShaperRemappingExpressionVisitor : ExpressionVisitor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ protected override ShapedQueryExpression TranslateFirstOrDefault(ShapedQueryExpr
predicate,
returnType,
returnDefault
? InMemoryLinqOperatorProvider.FirstOrDefaultPredicate
: InMemoryLinqOperatorProvider.FirstPredicate);
? InMemoryLinqOperatorProvider.FirstOrDefault
: InMemoryLinqOperatorProvider.First);
}

protected override ShapedQueryExpression TranslateGroupBy(ShapedQueryExpression source, LambdaExpression keySelector, LambdaExpression elementSelector, LambdaExpression resultSelector)
Expand Down Expand Up @@ -258,8 +258,8 @@ protected override ShapedQueryExpression TranslateLastOrDefault(ShapedQueryExpre
predicate,
returnType,
returnDefault
? InMemoryLinqOperatorProvider.LastOrDefaultPredicate
: InMemoryLinqOperatorProvider.LastPredicate);
? InMemoryLinqOperatorProvider.LastOrDefault
: InMemoryLinqOperatorProvider.Last);
}

protected override ShapedQueryExpression TranslateLeftJoin(ShapedQueryExpression outer, ShapedQueryExpression inner, LambdaExpression outerKeySelector, LambdaExpression innerKeySelector, LambdaExpression resultSelector)
Expand Down Expand Up @@ -521,8 +521,8 @@ protected override ShapedQueryExpression TranslateSingleOrDefault(ShapedQueryExp
predicate,
returnType,
returnDefault
? InMemoryLinqOperatorProvider.SingleOrDefaultPredicate
: InMemoryLinqOperatorProvider.SinglePredicate);
? InMemoryLinqOperatorProvider.SingleOrDefault
: InMemoryLinqOperatorProvider.Single);
}

protected override ShapedQueryExpression TranslateSkip(ShapedQueryExpression source, Expression count)
Expand Down Expand Up @@ -662,20 +662,19 @@ private ShapedQueryExpression TranslateSingleResultOperator(
{
var inMemoryQueryExpression = (InMemoryQueryExpression)source.QueryExpression;

predicate = predicate == null
? Expression.Lambda(Expression.Constant(true), Expression.Parameter(typeof(ValueBuffer)))
: TranslateLambdaExpression(source, predicate);

if (predicate == null)
if (predicate != null)
{
return null;
source = TranslateWhere(source, predicate);
if (source == null)
{
return null;
}
}

inMemoryQueryExpression.ServerQueryExpression =
Expression.Call(
method.MakeGenericMethod(typeof(ValueBuffer)),
inMemoryQueryExpression.ServerQueryExpression,
predicate);
inMemoryQueryExpression.ServerQueryExpression);

inMemoryQueryExpression.ConvertToEnumerable();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
Expand Down Expand Up @@ -36,6 +37,10 @@ private static readonly MethodInfo _materializeCollectionMethodInfo
= typeof(CustomShaperCompilingExpressionVisitor).GetTypeInfo()
.GetDeclaredMethod(nameof(MaterializeCollection));

private static readonly MethodInfo _materializeSingleResultMethodInfo
= typeof(CustomShaperCompilingExpressionVisitor).GetTypeInfo()
.GetDeclaredMethod(nameof(MaterializeSingleResult));

private static void SetIsLoadedNoTracking(object entity, INavigation navigation)
=> ((ILazyLoader)(navigation
.DeclaringEntityType
Expand Down Expand Up @@ -144,6 +149,14 @@ private static TCollection MaterializeCollection<TElement, TCollection>(
return collection;
}

private static TResult MaterializeSingleResult<TResult>(
QueryContext queryContext,
ValueBuffer valueBuffer,
Func<QueryContext, ValueBuffer, TResult> innerShaper)
=> valueBuffer.IsEmpty
? default
: innerShaper(queryContext, valueBuffer);

protected override Expression VisitExtension(Expression extensionExpression)
{
if (extensionExpression is IncludeExpression includeExpression)
Expand Down Expand Up @@ -203,6 +216,15 @@ protected override Expression VisitExtension(Expression extensionExpression)
typeof(IClrCollectionAccessor)));
}

if (extensionExpression is SingleResultShaperExpression singleResultShaperExpression)
{
return Expression.Call(
_materializeSingleResultMethodInfo.MakeGenericMethod(singleResultShaperExpression.Type),
QueryCompilationContext.QueryContextParameter,
singleResultShaperExpression.Projection,
Expression.Constant(((LambdaExpression)Visit(singleResultShaperExpression.InnerShaper)).Compile()));
}

return base.VisitExtension(extensionExpression);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,27 @@ protected override Expression VisitExtension(Expression extensionExpression)

return variable;
}

case SingleResultShaperExpression singleResultShaperExpression:
{
var key = GenerateKey((ProjectionBindingExpression)singleResultShaperExpression.Projection);
if (!_mapping.TryGetValue(key, out var variable))
{
var projection = Visit(singleResultShaperExpression.Projection);

variable = Expression.Parameter(singleResultShaperExpression.Type);
_variables.Add(variable);

var innerLambda = (LambdaExpression)singleResultShaperExpression.InnerShaper;
var innerShaper = new ShaperExpressionProcessingExpressionVisitor(null, innerLambda.Parameters[0])
.Inject(innerLambda.Body);

_expressions.Add(Expression.Assign(variable, singleResultShaperExpression.Update(projection, innerShaper)));
_mapping[key] = variable;
}

return variable;
}
}

return base.VisitExtension(extensionExpression);
Expand Down
54 changes: 54 additions & 0 deletions src/EFCore.InMemory/Query/Internal/SingleResultShaperExpression.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Query;

namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal
{
public class SingleResultShaperExpression : Expression, IPrintableExpression
{
public SingleResultShaperExpression(
Expression projection,
Expression innerShaper,
Type type)
{
Projection = projection;
InnerShaper = innerShaper;
Type = type;
}

protected override Expression VisitChildren(ExpressionVisitor visitor)
{
var projection = visitor.Visit(Projection);
var innerShaper = visitor.Visit(InnerShaper);

return Update(projection, innerShaper);
}

public virtual SingleResultShaperExpression Update(Expression projection, Expression innerShaper)
=> projection != Projection || innerShaper != InnerShaper
? new SingleResultShaperExpression(projection, innerShaper, Type)
: this;

public sealed override ExpressionType NodeType => ExpressionType.Extension;
public override Type Type { get; }

public virtual Expression Projection { get; }
public virtual Expression InnerShaper { get; }

public virtual void Print(ExpressionPrinter expressionPrinter)
{
expressionPrinter.AppendLine($"{nameof(SingleResultShaperExpression)}:");
using (expressionPrinter.Indent())
{
expressionPrinter.Append("(");
expressionPrinter.Visit(Projection);
expressionPrinter.Append(", ");
expressionPrinter.Visit(InnerShaper);
expressionPrinter.AppendLine($")");
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,45 +17,6 @@ public QueryNavigationsInMemoryTest(
//TestLoggerFactory.TestOutputHelper = testOutputHelper;
}

#region SingleResultProjection

public override Task Collection_select_nav_prop_first_or_default(bool isAsync)
{
return Task.CompletedTask;
}

public override Task Collection_select_nav_prop_first_or_default_then_nav_prop(bool isAsync)
{
return Task.CompletedTask;
}

public override Task Project_single_entity_value_subquery_works(bool isAsync)
{
return Task.CompletedTask;
}

public override Task Select_collection_FirstOrDefault_project_anonymous_type(bool isAsync)
{
return Task.CompletedTask;
}

public override Task Select_collection_FirstOrDefault_project_entity(bool isAsync)
{
return Task.CompletedTask;
}

public override Task Skip_Select_Navigation(bool isAsync)
{
return Task.CompletedTask;
}

public override Task Take_Select_Navigation(bool isAsync)
{
return Task.CompletedTask;
}

#endregion

[ConditionalTheory(Skip = "Issue#17386")]
public override Task Where_subquery_on_navigation_client_eval(bool isAsync)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -445,13 +445,6 @@ public override Task Select_DTO_with_member_init_distinct_in_subquery_translated

#endregion

#region SingleResultProjection
public override Task Project_single_element_from_collection_with_OrderBy_over_navigation_Take_and_FirstOrDefault_2(bool isAsync)
{
return Task.CompletedTask;
}
#endregion

#region NullableError

public override Task Project_single_element_from_collection_with_OrderBy_Distinct_and_FirstOrDefault_followed_by_projecting_length(bool isAsync)
Expand Down