Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

InMemory: Support for DefaultIfEmpty #17484

Merged
merged 3 commits into from
Aug 29, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion src/EFCore.InMemory/Query/Internal/EntityProjectionExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
using System.Collections.Generic;
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query;

namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal
{
public class EntityProjectionExpression : Expression
public class EntityProjectionExpression : Expression, IPrintableExpression
{
private readonly IDictionary<IProperty, Expression> _readExpressionMap;

Expand Down Expand Up @@ -50,5 +51,19 @@ public virtual Expression BindProperty(IProperty property)

return _readExpressionMap[property];
}

public virtual void Print(ExpressionPrinter expressionPrinter)
{
expressionPrinter.AppendLine(nameof(EntityProjectionExpression) + ":");
using (expressionPrinter.Indent())
{
foreach (var readExpressionMapEntry in _readExpressionMap)
{
expressionPrinter.Append(readExpressionMapEntry.Key + " -> ");
expressionPrinter.Visit(readExpressionMapEntry.Value);
expressionPrinter.AppendLine();
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;
Expand Down Expand Up @@ -69,27 +70,54 @@ public virtual Expression Translate(Expression expression)

protected override Expression VisitBinary(BinaryExpression binaryExpression)
{
var left = Visit(binaryExpression.Left);
var right = Visit(binaryExpression.Right);
if (left == null || right == null)
var newLeft = Visit(binaryExpression.Left);
var newRight = Visit(binaryExpression.Right);

if (newLeft == null || newRight == null)
{
return null;
}

return binaryExpression.Update(left, binaryExpression.Conversion, right);
if (IsConvertedToNullable(newLeft, binaryExpression.Left)
|| IsConvertedToNullable(newRight, binaryExpression.Right))
{
newLeft = ConvertToNullable(newLeft);
newRight = ConvertToNullable(newRight);
}

return Expression.MakeBinary(
binaryExpression.NodeType,
newLeft,
newRight,
binaryExpression.IsLiftedToNull,
binaryExpression.Method,
binaryExpression.Conversion);
}

protected override Expression VisitConditional(ConditionalExpression conditionalExpression)
{
var test = Visit(conditionalExpression.Test);
var ifTrue = Visit(conditionalExpression.IfTrue);
var ifFalse = Visit(conditionalExpression.IfFalse);

if (test == null || ifTrue == null || ifFalse == null)
{
return null;
}

return conditionalExpression.Update(test, ifTrue, ifFalse);
if (test.Type == typeof(bool?))
{
test = Expression.Equal(test, Expression.Constant(true, typeof(bool?)));
}

if (IsConvertedToNullable(ifTrue, conditionalExpression.IfTrue)
|| IsConvertedToNullable(ifFalse, conditionalExpression.IfFalse))
{
ifTrue = ConvertToNullable(ifTrue);
ifFalse = ConvertToNullable(ifFalse);
}

return Expression.Condition(test, ifTrue, ifFalse);
}

protected override Expression VisitMember(MemberExpression memberExpression)
Expand All @@ -109,7 +137,25 @@ protected override Expression VisitMember(MemberExpression memberExpression)
return result;
}

return memberExpression.Update(innerExpression);
static bool shouldApplyNullProtectionForMemberAccess(Type callerType, string memberName)
=> !(callerType.IsGenericType
&& callerType.GetGenericTypeDefinition() == typeof(Nullable<>)
&& (memberName == nameof(Nullable<int>.Value) || memberName == nameof(Nullable<int>.HasValue)));

var updatedMemberExpression = (Expression)memberExpression.Update(innerExpression);
if (innerExpression != null
&& innerExpression.Type.IsNullableType()
&& shouldApplyNullProtectionForMemberAccess(innerExpression.Type, memberExpression.Member.Name))
{
updatedMemberExpression = ConvertToNullable(updatedMemberExpression);

return Expression.Condition(
Expression.Equal(innerExpression, Expression.Default(innerExpression.Type)),
Expression.Default(updatedMemberExpression.Type),
updatedMemberExpression);
}

return updatedMemberExpression;
}

private bool TryBindMember(Expression source, MemberIdentity memberIdentity, Type type, out Expression result)
Expand Down Expand Up @@ -150,7 +196,12 @@ private bool TryBindMember(Expression source, MemberIdentity memberIdentity, Typ
}

result = BindProperty(entityProjection, property);
if (result.Type != type)

// if the result type change was just nullability change e.g from int to int? we want to preserve the new type for null propagation
if (result.Type != type
&& !(result.Type.IsNullableType()
&& !type.IsNullableType()
&& result.Type.UnwrapNullableType() == type))
{
result = Expression.Convert(result, type);
}
Expand All @@ -161,10 +212,23 @@ private bool TryBindMember(Expression source, MemberIdentity memberIdentity, Typ
return false;
}

private static bool IsConvertedToNullable(Expression result, Expression original)
=> result.Type.IsNullableType()
&& !original.Type.IsNullableType()
&& result.Type.UnwrapNullableType() == original.Type;

private static Expression ConvertToNullable(Expression expression)
=> !expression.Type.IsNullableType()
? Expression.Convert(expression, expression.Type.MakeNullable())
: expression;

private static Expression ConvertToNonNullable(Expression expression)
=> expression.Type.IsNullableType()
? Expression.Convert(expression, expression.Type.UnwrapNullableType())
: expression;

private static Expression BindProperty(EntityProjectionExpression entityProjectionExpression, IProperty property)
{
return entityProjectionExpression.BindProperty(property);
}
=> entityProjectionExpression.BindProperty(property);

private static Expression GetSelector(MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression)
{
Expand Down Expand Up @@ -330,16 +394,45 @@ MethodInfo getMethod()
}

var arguments = new Expression[methodCallExpression.Arguments.Count];
var parameterTypes = methodCallExpression.Method.GetParameters().Select(p => p.ParameterType).ToArray();
for (var i = 0; i < arguments.Length; i++)
{
var argument = Visit(methodCallExpression.Arguments[i]);
if (TranslationFailed(methodCallExpression.Arguments[i], argument))
{
return null;
}

// if the nullability of arguments change, we have no easy/reliable way to adjust the actual methodInfo to match the new type,
// so we are forced to cast back to the original type
if (IsConvertedToNullable(argument, methodCallExpression.Arguments[i])
&& !parameterTypes[i].IsAssignableFrom(argument.Type))
{
argument = ConvertToNonNullable(argument);
}

arguments[i] = argument;
}

// if object is nullable, add null safeguard before calling the function
// we special-case Nullable<>.GetValueOrDefault, which doesn't need the safeguard
if (methodCallExpression.Object != null
&& @object.Type.IsNullableType()
&& !(methodCallExpression.Method.Name == nameof(Nullable<int>.GetValueOrDefault)))
{
var result = (Expression)methodCallExpression.Update(
Expression.Convert(@object, methodCallExpression.Object.Type),
arguments);

result = ConvertToNullable(result);
result = Expression.Condition(
Expression.Equal(@object, Expression.Constant(null, @object.Type)),
Expression.Constant(null, result.Type),
result);

return result;
}

return methodCallExpression.Update(@object, arguments);
}

Expand Down Expand Up @@ -383,6 +476,51 @@ protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExp
return Expression.Constant(false);
}

protected override Expression VisitNew(NewExpression newExpression)
{
var newArguments = new List<Expression>();
foreach (var argument in newExpression.Arguments)
{
var newArgument = Visit(argument);
if (IsConvertedToNullable(newArgument, argument))
{
newArgument = ConvertToNonNullable(newArgument);
}

newArguments.Add(newArgument);
}

return newExpression.Update(newArguments);
}

protected override Expression VisitNewArray(NewArrayExpression newArrayExpression)
{
var newExpressions = new List<Expression>();
foreach (var expression in newArrayExpression.Expressions)
{
var newExpression = Visit(expression);
if (IsConvertedToNullable(newExpression, expression))
{
newExpression = ConvertToNonNullable(newExpression);
}

newExpressions.Add(newExpression);
}

return newArrayExpression.Update(newExpressions);
}

protected override MemberAssignment VisitMemberAssignment(MemberAssignment memberAssignment)
{
var expression = Visit(memberAssignment.Expression);
if (IsConvertedToNullable(expression, memberAssignment.Expression))
{
expression = ConvertToNonNullable(expression);
}

return memberAssignment.Update(expression);
}

protected override Expression VisitExtension(Expression extensionExpression)
{
switch (extensionExpression)
Expand Down Expand Up @@ -441,7 +579,15 @@ private static T GetParameterValue<T>(QueryContext queryContext, string paramete

protected override Expression VisitUnary(UnaryExpression unaryExpression)
{
var result = base.VisitUnary(unaryExpression);
var newOperand = Visit(unaryExpression.Operand);

if (unaryExpression.NodeType == ExpressionType.Convert
&& newOperand.Type == unaryExpression.Type)
{
return newOperand;
}

var result = (Expression)Expression.MakeUnary(unaryExpression.NodeType, newOperand, unaryExpression.Type);
if (result is UnaryExpression outerUnary
&& outerUnary.NodeType == ExpressionType.Convert
&& outerUnary.Operand is UnaryExpression innerUnary
Expand Down
15 changes: 13 additions & 2 deletions src/EFCore.InMemory/Query/Internal/InMemoryLinqOperatorProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,20 @@ public static MethodInfo GetMinWithSelector(Type type)
private static Dictionary<Type, MethodInfo> SumWithoutSelectorMethods { get; }
private static Dictionary<Type, MethodInfo> SumWithSelectorMethods { get; }

private static bool IsFunc(Type type, int funcGenericArgs = 2)
private static Type GetFuncType(int funcGenericArguments)
{
return funcGenericArguments switch
{
1 => typeof(Func<>),
2 => typeof(Func<,>),
3 => typeof(Func<,,>),
4 => typeof(Func<,,,>),
_ => throw new InvalidOperationException("Invalid number of arguments for Func"),
};
}
private static bool IsFunc(Type type, int funcGenericArguments = 2)
=> type.IsGenericType
&& type.GetGenericArguments().Length == funcGenericArgs;
&& type.GetGenericTypeDefinition() == GetFuncType(funcGenericArguments);

static InMemoryLinqOperatorProvider()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,17 @@ public override Expression Visit(Expression expression)
}

var translation = _expressionTranslatingExpressionVisitor.Translate(expression);
return translation == null
? base.Visit(expression)
: new ProjectionBindingExpression(_queryExpression, _queryExpression.AddToProjection(translation), expression.Type);
if (translation == null)
{
return base.Visit(expression);
}

if (translation.Type != expression.Type)
{
translation = Expression.Convert(translation, expression.Type);
}

return new ProjectionBindingExpression(_queryExpression, _queryExpression.AddToProjection(translation), expression.Type);
}
else
{
Expand All @@ -139,6 +147,11 @@ public override Expression Visit(Expression expression)
return null;
}

if (translation.Type != expression.Type)
{
translation = Expression.Convert(translation, expression.Type);
}

_projectionMapping[_projectionMembers.Peek()] = translation;

return new ProjectionBindingExpression(_queryExpression, _projectionMembers.Peek(), expression.Type);
Expand Down
Loading