Skip to content

Commit

Permalink
Add null propagation/protection logic for InMemory provider.
Browse files Browse the repository at this point in the history
When we bind to a non-nullable property on entity that can be nullable (e.g. due to left join) we modify the result to be nullable, to avoid "nullable object must have a value" errors. This nullability is then propagated further. We have few blockers:
- predicates (Where, Any, First, Count, etc): always need to be of type bool. When necessary we add "== true".
- conditional expression Test: needs to be bool, same as above
- method call arguments: we can't reliably rewrite methodcall when the arguments types change (generic methods specifically), we convert arguments back to their original types if they were changed to nullable versions.
- method call caller: if the caller was changed from non-nullable to nullable we still need to call the method with the original type, but we add null check before - caller.Method(args) -> nullable_caller == null ? null : (resultType?)caller.Method(args)
- selectors (Select, Max etc): we need to preserve the original result type, we use convert
- anonymous type, array init: we need to preserve the original type, we use convert

Also enable GearsOfWar and ComplexNavigation tests for in memory.
  • Loading branch information
maumar authored and smitpatel committed Aug 29, 2019
1 parent 94cb8f3 commit 3836879
Show file tree
Hide file tree
Showing 12 changed files with 581 additions and 150 deletions.
17 changes: 16 additions & 1 deletion src/EFCore.InMemory/Query/Internal/EntityProjectionExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
using System.Collections.Generic;
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query;

namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal
{
public class EntityProjectionExpression : Expression
public class EntityProjectionExpression : Expression, IPrintableExpression
{
private readonly IDictionary<IProperty, Expression> _readExpressionMap;

Expand Down Expand Up @@ -50,5 +51,19 @@ public virtual Expression BindProperty(IProperty property)

return _readExpressionMap[property];
}

public virtual void Print(ExpressionPrinter expressionPrinter)
{
expressionPrinter.AppendLine(nameof(EntityProjectionExpression) + ":");
using (expressionPrinter.Indent())
{
foreach (var readExpressionMapEntry in _readExpressionMap)
{
expressionPrinter.Append(readExpressionMapEntry.Key + " -> ");
expressionPrinter.Visit(readExpressionMapEntry.Value);
expressionPrinter.AppendLine();
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;
Expand Down Expand Up @@ -69,27 +70,59 @@ public virtual Expression Translate(Expression expression)

protected override Expression VisitBinary(BinaryExpression binaryExpression)
{
var left = Visit(binaryExpression.Left);
var right = Visit(binaryExpression.Right);
if (left == null || right == null)
var newLeft = Visit(binaryExpression.Left);
var newRight = Visit(binaryExpression.Right);

if (newLeft == null || newRight == null)
{
return null;
}

return binaryExpression.Update(left, binaryExpression.Conversion, right);
if (TypeNullabilityChanged(newLeft.Type, binaryExpression.Left.Type)
|| TypeNullabilityChanged(newRight.Type, binaryExpression.Right.Type))
{
newLeft = MakeNullable(newLeft);
newRight = MakeNullable(newRight);
}

return Expression.MakeBinary(
binaryExpression.NodeType,
newLeft,
newRight,
binaryExpression.IsLiftedToNull,
binaryExpression.Method,
binaryExpression.Conversion);
}

private static Expression MakeNullable(Expression expression)
=> !expression.Type.IsNullableType()
? Expression.Convert(expression, expression.Type.MakeNullable())
: expression;

protected override Expression VisitConditional(ConditionalExpression conditionalExpression)
{
var test = Visit(conditionalExpression.Test);
var ifTrue = Visit(conditionalExpression.IfTrue);
var ifFalse = Visit(conditionalExpression.IfFalse);

if (test == null || ifTrue == null || ifFalse == null)
{
return null;
}

return conditionalExpression.Update(test, ifTrue, ifFalse);
if (test.Type == typeof(bool?))
{
test = Expression.Equal(test, Expression.Constant(true, typeof(bool?)));
}

if (TypeNullabilityChanged(ifTrue.Type, conditionalExpression.IfTrue.Type)
|| TypeNullabilityChanged(ifFalse.Type, conditionalExpression.IfFalse.Type))
{
ifTrue = MakeNullable(ifTrue);
ifFalse = MakeNullable(ifFalse);
}

return Expression.Condition(test, ifTrue, ifFalse);
}

protected override Expression VisitMember(MemberExpression memberExpression)
Expand All @@ -109,9 +142,27 @@ protected override Expression VisitMember(MemberExpression memberExpression)
return result;
}

return memberExpression.Update(innerExpression);
var updatedMemberExpression = (Expression)memberExpression.Update(innerExpression);
if (innerExpression != null
&& innerExpression.Type.IsNullableType()
&& ShouldApplyNullProtectionForMemberAccess(innerExpression.Type, memberExpression.Member.Name))
{
updatedMemberExpression = MakeNullable(updatedMemberExpression);

return Expression.Condition(
Expression.Equal(innerExpression, Expression.Default(innerExpression.Type)),
Expression.Default(updatedMemberExpression.Type),
updatedMemberExpression);
}

return updatedMemberExpression;
}

private bool ShouldApplyNullProtectionForMemberAccess(Type callerType, string memberName)
=> !(callerType.IsGenericType
&& callerType.GetGenericTypeDefinition() == typeof(Nullable<>)
&& (memberName == nameof(Nullable<int>.Value) || memberName == nameof(Nullable<int>.HasValue)));

private bool TryBindMember(Expression source, MemberIdentity memberIdentity, Type type, out Expression result)
{
result = null;
Expand Down Expand Up @@ -150,7 +201,9 @@ private bool TryBindMember(Expression source, MemberIdentity memberIdentity, Typ
}

result = BindProperty(entityProjection, property);
if (result.Type != type)

// if the result type change was just nullability change e.g from int to int? we want to preserve the new type for null propagation
if (result.Type != type && !TypeNullabilityChanged(result.Type, type))
{
result = Expression.Convert(result, type);
}
Expand All @@ -161,6 +214,9 @@ private bool TryBindMember(Expression source, MemberIdentity memberIdentity, Typ
return false;
}

private bool TypeNullabilityChanged(Type maybeNullableType, Type nonNullableType)
=> maybeNullableType.IsNullableType() && !nonNullableType.IsNullableType() && maybeNullableType.UnwrapNullableType() == nonNullableType;

private static Expression BindProperty(EntityProjectionExpression entityProjectionExpression, IProperty property)
{
return entityProjectionExpression.BindProperty(property);
Expand Down Expand Up @@ -330,16 +386,45 @@ MethodInfo getMethod()
}

var arguments = new Expression[methodCallExpression.Arguments.Count];
var parameterTypes = methodCallExpression.Method.GetParameters().Select(p => p.ParameterType).ToArray();
for (var i = 0; i < arguments.Length; i++)
{
var argument = Visit(methodCallExpression.Arguments[i]);
if (TranslationFailed(methodCallExpression.Arguments[i], argument))
{
return null;
}

// if the nullability of arguments change, we have no easy/reliable way to adjust the actual methodInfo to match the new type,
// so we are forced to cast back to the original type
if (argument.Type != methodCallExpression.Arguments[i].Type
&& !parameterTypes[i].IsAssignableFrom(argument.Type))
{
argument = Expression.Convert(argument, methodCallExpression.Arguments[i].Type);
}

arguments[i] = argument;
}

// if object is nullable, add null safeguard before calling the function
// we special-case Nullable<>.GetValueOrDefault, which doesn't need the safeguard
if (methodCallExpression.Object != null
&& @object.Type.IsNullableType()
&& !(methodCallExpression.Method.Name == nameof(Nullable<int>.GetValueOrDefault)))
{
var result = (Expression)methodCallExpression.Update(
Expression.Convert(@object, methodCallExpression.Object.Type),
arguments);

result = MakeNullable(result);
result = Expression.Condition(
Expression.Equal(@object, Expression.Constant(null, @object.Type)),
Expression.Constant(null, result.Type),
result);

return result;
}

return methodCallExpression.Update(@object, arguments);
}

Expand Down Expand Up @@ -383,6 +468,68 @@ protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExp
return Expression.Constant(false);
}

protected override Expression VisitNew(NewExpression newExpression)
{
var newArguments = new List<Expression>();
foreach (var argument in newExpression.Arguments)
{
var newArgument = Visit(argument);
if (newArgument.Type != argument.Type)
{
newArgument = Expression.Convert(newArgument, argument.Type);
}

newArguments.Add(newArgument);
}

return newExpression.Update(newArguments);
}

protected override Expression VisitNewArray(NewArrayExpression newArrayExpression)
{
var newExpressions = new List<Expression>();
foreach (var expression in newArrayExpression.Expressions)
{
var newExpression = Visit(expression);
if (newExpression.Type != expression.Type)
{
newExpression = Expression.Convert(newExpression, expression.Type);
}

newExpressions.Add(newExpression);
}

return newArrayExpression.Update(newExpressions);
}

protected override Expression VisitMemberInit(MemberInitExpression memberInitExpression)
{
var newExpression = (NewExpression)Visit(memberInitExpression.NewExpression);
var bindings = new List<MemberBinding>();
foreach (var binding in memberInitExpression.Bindings)
{
switch (binding)
{
case MemberAssignment memberAssignment:
var expression = Visit(memberAssignment.Expression);
if (expression.Type != memberAssignment.Expression.Type)
{
expression = Expression.Convert(expression, memberAssignment.Expression.Type);
}

bindings.Add(Expression.Bind(memberAssignment.Member, expression));
break;

default:
// TODO: MemberMemberBinding and MemberListBinding
bindings.Add(binding);
break;
}
}

return memberInitExpression.Update(newExpression, bindings);
}

protected override Expression VisitExtension(Expression extensionExpression)
{
switch (extensionExpression)
Expand Down Expand Up @@ -441,7 +588,15 @@ private static T GetParameterValue<T>(QueryContext queryContext, string paramete

protected override Expression VisitUnary(UnaryExpression unaryExpression)
{
var result = base.VisitUnary(unaryExpression);
var newOperand = Visit(unaryExpression.Operand);

if (unaryExpression.NodeType == ExpressionType.Convert
&& newOperand.Type == unaryExpression.Type)
{
return newOperand;
}

var result = (Expression)Expression.MakeUnary(unaryExpression.NodeType, newOperand, unaryExpression.Type);
if (result is UnaryExpression outerUnary
&& outerUnary.NodeType == ExpressionType.Convert
&& outerUnary.Operand is UnaryExpression innerUnary
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,17 @@ public override Expression Visit(Expression expression)
}

var translation = _expressionTranslatingExpressionVisitor.Translate(expression);
return translation == null
? base.Visit(expression)
: new ProjectionBindingExpression(_queryExpression, _queryExpression.AddToProjection(translation), expression.Type);
if (translation == null)
{
return base.Visit(expression);
}

if (translation.Type != expression.Type)
{
translation = Expression.Convert(translation, expression.Type);
}

return new ProjectionBindingExpression(_queryExpression, _queryExpression.AddToProjection(translation), expression.Type);
}
else
{
Expand All @@ -139,6 +147,11 @@ public override Expression Visit(Expression expression)
return null;
}

if (translation.Type != expression.Type)
{
translation = Expression.Convert(translation, expression.Type);
}

_projectionMapping[_projectionMembers.Peek()] = translation;

return new ProjectionBindingExpression(_queryExpression, _projectionMembers.Peek(), expression.Type);
Expand Down
27 changes: 26 additions & 1 deletion src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal
{
public partial class InMemoryQueryExpression : Expression
public partial class InMemoryQueryExpression : Expression, IPrintableExpression
{
private static readonly ConstructorInfo _valueBufferConstructor
= typeof(ValueBuffer).GetConstructors().Single(ci => ci.GetParameters().Length == 1);
Expand Down Expand Up @@ -751,6 +751,31 @@ public virtual void AddSelectMany(InMemoryQueryExpression innerQueryExpression,
_projectionMapping = projectionMapping;
}

public virtual void Print(ExpressionPrinter expressionPrinter)
{
expressionPrinter.AppendLine(nameof(InMemoryQueryExpression) + ": ");
using (expressionPrinter.Indent())
{
expressionPrinter.AppendLine(nameof(ServerQueryExpression) + ": ");
using (expressionPrinter.Indent())
{
expressionPrinter.Visit(ServerQueryExpression);
}

expressionPrinter.AppendLine("ProjectionMapping:");
using (expressionPrinter.Indent())
{
foreach (var projectionMapping in _projectionMapping)
{
expressionPrinter.Append("Member: " + projectionMapping.Key + " Projection: ");
expressionPrinter.Visit(projectionMapping.Value);
}
}

expressionPrinter.AppendLine();
}
}

private class NullableReadValueExpressionVisitor : ExpressionVisitor
{
protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
Expand Down
Loading

0 comments on commit 3836879

Please sign in to comment.