Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add null propagation/protection logic for InMemory provider. #17485

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion src/EFCore.InMemory/Query/Internal/EntityProjectionExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
using System.Collections.Generic;
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query;

namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal
{
public class EntityProjectionExpression : Expression
public class EntityProjectionExpression : Expression, IPrintableExpression
{
private readonly IDictionary<IProperty, Expression> _readExpressionMap;

Expand Down Expand Up @@ -50,5 +51,19 @@ public virtual Expression BindProperty(IProperty property)

return _readExpressionMap[property];
}

public virtual void Print(ExpressionPrinter expressionPrinter)
{
expressionPrinter.AppendLine(nameof(EntityProjectionExpression) + ":");
using (expressionPrinter.Indent())
{
foreach (var readExpressionMapEntry in _readExpressionMap)
{
expressionPrinter.Append(readExpressionMapEntry.Key + " -> ");
expressionPrinter.Visit(readExpressionMapEntry.Value);
expressionPrinter.AppendLine();
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;
Expand Down Expand Up @@ -69,27 +70,59 @@ public virtual Expression Translate(Expression expression)

protected override Expression VisitBinary(BinaryExpression binaryExpression)
{
var left = Visit(binaryExpression.Left);
var right = Visit(binaryExpression.Right);
if (left == null || right == null)
var newLeft = Visit(binaryExpression.Left);
var newRight = Visit(binaryExpression.Right);

if (newLeft == null || newRight == null)
{
return null;
}

return binaryExpression.Update(left, binaryExpression.Conversion, right);
if (TypeNullabilityChanged(newLeft.Type, binaryExpression.Left.Type)
|| TypeNullabilityChanged(newRight.Type, binaryExpression.Right.Type))
{
newLeft = MakeNullable(newLeft);
newRight = MakeNullable(newRight);
}

return Expression.MakeBinary(
binaryExpression.NodeType,
newLeft,
newRight,
binaryExpression.IsLiftedToNull,
binaryExpression.Method,
binaryExpression.Conversion);
}

private static Expression MakeNullable(Expression expression)
=> !expression.Type.IsNullableType()
? Expression.Convert(expression, expression.Type.MakeNullable())
: expression;

protected override Expression VisitConditional(ConditionalExpression conditionalExpression)
{
var test = Visit(conditionalExpression.Test);
var ifTrue = Visit(conditionalExpression.IfTrue);
var ifFalse = Visit(conditionalExpression.IfFalse);

if (test == null || ifTrue == null || ifFalse == null)
{
return null;
}

return conditionalExpression.Update(test, ifTrue, ifFalse);
if (test.Type == typeof(bool?))
{
test = Expression.Equal(test, Expression.Constant(true, typeof(bool?)));
}

if (TypeNullabilityChanged(ifTrue.Type, conditionalExpression.IfTrue.Type)
|| TypeNullabilityChanged(ifFalse.Type, conditionalExpression.IfFalse.Type))
{
ifTrue = MakeNullable(ifTrue);
ifFalse = MakeNullable(ifFalse);
}

return Expression.Condition(test, ifTrue, ifFalse);
}

protected override Expression VisitMember(MemberExpression memberExpression)
Expand All @@ -109,9 +142,27 @@ protected override Expression VisitMember(MemberExpression memberExpression)
return result;
}

return memberExpression.Update(innerExpression);
var updatedMemberExpression = (Expression)memberExpression.Update(innerExpression);
if (innerExpression != null
&& innerExpression.Type.IsNullableType()
&& ShouldApplyNullProtectionForMemberAccess(innerExpression.Type, memberExpression.Member.Name))
{
updatedMemberExpression = MakeNullable(updatedMemberExpression);

return Expression.Condition(
Expression.Equal(innerExpression, Expression.Default(innerExpression.Type)),
Expression.Default(updatedMemberExpression.Type),
updatedMemberExpression);
}

return updatedMemberExpression;
}

private bool ShouldApplyNullProtectionForMemberAccess(Type callerType, string memberName)
=> !(callerType.IsGenericType
&& callerType.GetGenericTypeDefinition() == typeof(Nullable<>)
&& (memberName == nameof(Nullable<int>.Value) || memberName == nameof(Nullable<int>.HasValue)));

private bool TryBindMember(Expression source, MemberIdentity memberIdentity, Type type, out Expression result)
{
result = null;
Expand Down Expand Up @@ -150,7 +201,9 @@ private bool TryBindMember(Expression source, MemberIdentity memberIdentity, Typ
}

result = BindProperty(entityProjection, property);
if (result.Type != type)

// if the result type change was just nullability change e.g from int to int? we want to preserve the new type for null propagation
if (result.Type != type && !TypeNullabilityChanged(result.Type, type))
{
result = Expression.Convert(result, type);
}
Expand All @@ -161,6 +214,9 @@ private bool TryBindMember(Expression source, MemberIdentity memberIdentity, Typ
return false;
}

private bool TypeNullabilityChanged(Type maybeNullableType, Type nonNullableType)
=> maybeNullableType.IsNullableType() && !nonNullableType.IsNullableType() && maybeNullableType.UnwrapNullableType() == nonNullableType;

private static Expression BindProperty(EntityProjectionExpression entityProjectionExpression, IProperty property)
{
return entityProjectionExpression.BindProperty(property);
Expand Down Expand Up @@ -330,16 +386,45 @@ MethodInfo getMethod()
}

var arguments = new Expression[methodCallExpression.Arguments.Count];
var parameterTypes = methodCallExpression.Method.GetParameters().Select(p => p.ParameterType).ToArray();
for (var i = 0; i < arguments.Length; i++)
{
var argument = Visit(methodCallExpression.Arguments[i]);
if (TranslationFailed(methodCallExpression.Arguments[i], argument))
{
return null;
}

// if the nullability of arguments change, we have no easy/reliable way to adjust the actual methodInfo to match the new type,
// so we are forced to cast back to the original type
if (argument.Type != methodCallExpression.Arguments[i].Type
&& !parameterTypes[i].IsAssignableFrom(argument.Type))
{
argument = Expression.Convert(argument, methodCallExpression.Arguments[i].Type);
}

arguments[i] = argument;
}

// if object is nullable, add null safeguard before calling the function
// we special-case Nullable<>.GetValueOrDefault, which doesn't need the safeguard
if (methodCallExpression.Object != null
&& @object.Type.IsNullableType()
&& !(methodCallExpression.Method.Name == nameof(Nullable<int>.GetValueOrDefault)))
{
var result = (Expression)methodCallExpression.Update(
Expression.Convert(@object, methodCallExpression.Object.Type),
arguments);

result = MakeNullable(result);
result = Expression.Condition(
Expression.Equal(@object, Expression.Constant(null, @object.Type)),
Expression.Constant(null, result.Type),
result);

return result;
}

return methodCallExpression.Update(@object, arguments);
}

Expand Down Expand Up @@ -383,6 +468,68 @@ protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExp
return Expression.Constant(false);
}

protected override Expression VisitNew(NewExpression newExpression)
{
var newArguments = new List<Expression>();
foreach (var argument in newExpression.Arguments)
{
var newArgument = Visit(argument);
if (newArgument.Type != argument.Type)
{
newArgument = Expression.Convert(newArgument, argument.Type);
}

newArguments.Add(newArgument);
}

return newExpression.Update(newArguments);
}

protected override Expression VisitNewArray(NewArrayExpression newArrayExpression)
{
var newExpressions = new List<Expression>();
foreach (var expression in newArrayExpression.Expressions)
{
var newExpression = Visit(expression);
if (newExpression.Type != expression.Type)
{
newExpression = Expression.Convert(newExpression, expression.Type);
}

newExpressions.Add(newExpression);
}

return newArrayExpression.Update(newExpressions);
}

protected override Expression VisitMemberInit(MemberInitExpression memberInitExpression)
{
var newExpression = (NewExpression)Visit(memberInitExpression.NewExpression);
var bindings = new List<MemberBinding>();
foreach (var binding in memberInitExpression.Bindings)
{
switch (binding)
{
case MemberAssignment memberAssignment:
var expression = Visit(memberAssignment.Expression);
if (expression.Type != memberAssignment.Expression.Type)
{
expression = Expression.Convert(expression, memberAssignment.Expression.Type);
}

bindings.Add(Expression.Bind(memberAssignment.Member, expression));
break;

default:
// TODO: MemberMemberBinding and MemberListBinding
bindings.Add(binding);
break;
}
}

return memberInitExpression.Update(newExpression, bindings);
}

protected override Expression VisitExtension(Expression extensionExpression)
{
switch (extensionExpression)
Expand Down Expand Up @@ -441,7 +588,15 @@ private static T GetParameterValue<T>(QueryContext queryContext, string paramete

protected override Expression VisitUnary(UnaryExpression unaryExpression)
{
var result = base.VisitUnary(unaryExpression);
var newOperand = Visit(unaryExpression.Operand);

if (unaryExpression.NodeType == ExpressionType.Convert
&& newOperand.Type == unaryExpression.Type)
{
return newOperand;
}

var result = (Expression)Expression.MakeUnary(unaryExpression.NodeType, newOperand, unaryExpression.Type);
if (result is UnaryExpression outerUnary
&& outerUnary.NodeType == ExpressionType.Convert
&& outerUnary.Operand is UnaryExpression innerUnary
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,17 @@ public override Expression Visit(Expression expression)
}

var translation = _expressionTranslatingExpressionVisitor.Translate(expression);
return translation == null
? base.Visit(expression)
: new ProjectionBindingExpression(_queryExpression, _queryExpression.AddToProjection(translation), expression.Type);
if (translation == null)
{
return base.Visit(expression);
}

if (translation.Type != expression.Type)
{
translation = Expression.Convert(translation, expression.Type);
}

return new ProjectionBindingExpression(_queryExpression, _queryExpression.AddToProjection(translation), expression.Type);
}
else
{
Expand All @@ -139,6 +147,11 @@ public override Expression Visit(Expression expression)
return null;
}

if (translation.Type != expression.Type)
{
translation = Expression.Convert(translation, expression.Type);
}

_projectionMapping[_projectionMembers.Peek()] = translation;

return new ProjectionBindingExpression(_queryExpression, _projectionMembers.Peek(), expression.Type);
Expand Down
27 changes: 26 additions & 1 deletion src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal
{
public partial class InMemoryQueryExpression : Expression
public partial class InMemoryQueryExpression : Expression, IPrintableExpression
{
private static readonly ConstructorInfo _valueBufferConstructor
= typeof(ValueBuffer).GetConstructors().Single(ci => ci.GetParameters().Length == 1);
Expand Down Expand Up @@ -751,6 +751,31 @@ public virtual void AddSelectMany(InMemoryQueryExpression innerQueryExpression,
_projectionMapping = projectionMapping;
}

public virtual void Print(ExpressionPrinter expressionPrinter)
{
expressionPrinter.AppendLine(nameof(InMemoryQueryExpression) + ": ");
using (expressionPrinter.Indent())
{
expressionPrinter.AppendLine(nameof(ServerQueryExpression) + ": ");
using (expressionPrinter.Indent())
{
expressionPrinter.Visit(ServerQueryExpression);
}

expressionPrinter.AppendLine("ProjectionMapping:");
using (expressionPrinter.Indent())
{
foreach (var projectionMapping in _projectionMapping)
{
expressionPrinter.Append("Member: " + projectionMapping.Key + " Projection: ");
expressionPrinter.Visit(projectionMapping.Value);
}
}

expressionPrinter.AppendLine();
}
}

private class NullableReadValueExpressionVisitor : ExpressionVisitor
{
protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
Expand Down
Loading