Skip to content

Commit

Permalink
Query: Unwrap convert nodes around entity before doing member lookup
Browse files Browse the repository at this point in the history
Resolves #17794
  • Loading branch information
smitpatel committed Nov 1, 2019
1 parent fe0d550 commit 9b0ceb2
Show file tree
Hide file tree
Showing 8 changed files with 330 additions and 307 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -95,57 +95,44 @@ protected override Expression VisitExtension(Expression node)
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override Expression VisitMember(MemberExpression memberExpression)
{
var innerExpression = Visit(memberExpression.Expression);

if (TryBindMember(innerExpression, MemberIdentity.Create(memberExpression.Member), out var result))
{
return result;
}

return TranslationFailed(memberExpression.Expression, innerExpression)
? null
: _memberTranslatorProvider.Translate((SqlExpression)innerExpression, memberExpression.Member, memberExpression.Type);
}
=> TryBindMember(memberExpression.Expression, MemberIdentity.Create(memberExpression.Member), out var result)
? result
: TranslationFailed(memberExpression.Expression, Visit(memberExpression.Expression), out var sqlInnerExpression)
? null
: _memberTranslatorProvider.Translate(sqlInnerExpression, memberExpression.Member, memberExpression.Type);

private bool TryBindMember(Expression source, MemberIdentity member, out Expression expression)
{
Type convertedType = null;
if (source is UnaryExpression unaryExpression
&& unaryExpression.NodeType == ExpressionType.Convert)
source = source.UnwrapTypeConversion(out var convertedType);
Expression visitedExpression;
switch (source)
{
if (unaryExpression.Type != typeof(object))
{
convertedType = unaryExpression.Type;
}
case EntityShaperExpression entityShaperExpression:
visitedExpression = Visit(entityShaperExpression.ValueBufferExpression);
break;

case MemberExpression memberExpression:
TryBindMember(memberExpression.Expression, MemberIdentity.Create(memberExpression.Member), out visitedExpression);
break;

case MethodCallExpression methodCallExpression
when methodCallExpression.TryGetEFPropertyArguments(out var innerSource, out var innerPropertyName):
TryBindMember(innerSource, MemberIdentity.Create(innerPropertyName), out visitedExpression);
break;

source = unaryExpression.Operand;
default:
visitedExpression = null;
break;
}

if (source is EntityProjectionExpression entityProjectionExpression)
if (visitedExpression is EntityProjectionExpression entityProjectionExpression)
{
if (convertedType != null
&& convertedType.IsInterface
&& convertedType.IsAssignableFrom(entityProjectionExpression.Type))
{
convertedType = entityProjectionExpression.Type;
}

convertedType ??= entityProjectionExpression.Type;
expression = member.MemberInfo != null
? entityProjectionExpression.BindMember(member.MemberInfo, convertedType, clientEval: false, out _)
: entityProjectionExpression.BindMember(member.Name, convertedType, clientEval: false, out _);
return expression != null;
}

if (source is MemberExpression innerMemberExpression
&& TryBindMember(innerMemberExpression, MemberIdentity.Create(innerMemberExpression.Member), out var innerResult))
{
if (convertedType != null)
{
innerResult = Expression.Convert(innerResult, convertedType);
}

return TryBindMember(innerResult, member, out expression);
return expression != null;
}

expression = null;
Expand All @@ -162,30 +149,29 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
{
if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName))
{
return TryBindMember(Visit(source), MemberIdentity.Create(propertyName), out var result)
return TryBindMember(source, MemberIdentity.Create(propertyName), out var result)
? result
: null;
}

var @object = Visit(methodCallExpression.Object);
if (TranslationFailed(methodCallExpression.Object, @object))
if (TranslationFailed(methodCallExpression.Object, Visit(methodCallExpression.Object), out var sqlObject))
{
return null;
}

var arguments = new SqlExpression[methodCallExpression.Arguments.Count];
for (var i = 0; i < arguments.Length; i++)
{
var argument = Visit(methodCallExpression.Arguments[i]);
if (TranslationFailed(methodCallExpression.Arguments[i], argument))
var argument = methodCallExpression.Arguments[i];
if (TranslationFailed(argument, Visit(argument), out var sqlArgument))
{
return null;
}

arguments[i] = (SqlExpression)argument;
arguments[i] = sqlArgument;
}

return _methodCallTranslatorProvider.Translate(_model, (SqlExpression)@object, methodCallExpression.Method, arguments);
return _methodCallTranslatorProvider.Translate(_model, sqlObject, methodCallExpression.Method, arguments);
}

private static Expression TryRemoveImplicitConvert(Expression expression)
Expand Down Expand Up @@ -240,17 +226,14 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
left = Visit(left);
right = Visit(right);

if (TranslationFailed(binaryExpression.Left, left)
|| TranslationFailed(binaryExpression.Right, right))
{
return null;
}

return _sqlExpressionFactory.MakeBinary(
binaryExpression.NodeType,
(SqlExpression)left,
(SqlExpression)right,
null);
return TranslationFailed(binaryExpression.Left, left, out var sqlLeft)
|| TranslationFailed(binaryExpression.Right, right, out var sqlRight)
? null
: _sqlExpressionFactory.MakeBinary(
binaryExpression.NodeType,
sqlLeft,
sqlRight,
null);
}

/// <summary>
Expand All @@ -265,14 +248,11 @@ protected override Expression VisitConditional(ConditionalExpression conditional
var ifTrue = Visit(conditionalExpression.IfTrue);
var ifFalse = Visit(conditionalExpression.IfFalse);

if (TranslationFailed(conditionalExpression.Test, test)
|| TranslationFailed(conditionalExpression.IfTrue, ifTrue)
|| TranslationFailed(conditionalExpression.IfFalse, ifFalse))
{
return null;
}

return _sqlExpressionFactory.Condition((SqlExpression)test, (SqlExpression)ifTrue, (SqlExpression)ifFalse);
return TranslationFailed(conditionalExpression.Test, test, out var sqlTest)
|| TranslationFailed(conditionalExpression.IfTrue, ifTrue, out var sqlIfTrue)
|| TranslationFailed(conditionalExpression.IfFalse, ifFalse, out var sqlIfFalse)
? null
: _sqlExpressionFactory.Condition(sqlTest, sqlIfTrue, sqlIfFalse);
}

/// <summary>
Expand All @@ -285,18 +265,11 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)
{
var operand = Visit(unaryExpression.Operand);

if (operand is EntityProjectionExpression)
{
return unaryExpression.Update(operand);
}

if (TranslationFailed(unaryExpression.Operand, operand))
if (TranslationFailed(unaryExpression.Operand, operand, out var sqlOperand))
{
return null;
}

var sqlOperand = (SqlExpression)operand;

switch (unaryExpression.NodeType)
{
case ExpressionType.Not:
Expand Down Expand Up @@ -334,7 +307,9 @@ private SqlConstantExpression GetConstantOrNull(Expression expression)

private static bool CanEvaluate(Expression expression)
{
#pragma warning disable IDE0066 // Convert switch statement to expression
switch (expression)
#pragma warning restore IDE0066 // Convert switch statement to expression
{
case ConstantExpression constantExpression:
return true;
Expand Down Expand Up @@ -447,7 +422,17 @@ protected override Expression VisitExtension(Expression extensionExpression)
}

[DebuggerStepThrough]
private bool TranslationFailed(Expression original, Expression translation)
=> original != null && !(translation is SqlExpression);
private bool TranslationFailed(Expression original, Expression translation, out SqlExpression castTranslation)
{
if (original != null
&& !(translation is SqlExpression))
{
castTranslation = null;
return true;
}

castTranslation = translation as SqlExpression;
return false;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -125,22 +125,22 @@ protected override Expression VisitConditional(ConditionalExpression conditional

protected override Expression VisitMember(MemberExpression memberExpression)
{
if (TryBindMember(
memberExpression.Expression,
MemberIdentity.Create(memberExpression.Member),
memberExpression.Type,
out var result))
{
return result;
}

var innerExpression = Visit(memberExpression.Expression);
if (memberExpression.Expression != null
&& innerExpression == null)
{
return null;
}

if ((innerExpression is EntityProjectionExpression
|| (innerExpression is UnaryExpression innerUnaryExpression
&& innerUnaryExpression.NodeType == ExpressionType.Convert
&& innerUnaryExpression.Operand is EntityProjectionExpression))
&& TryBindMember(innerExpression, MemberIdentity.Create(memberExpression.Member), memberExpression.Type, out var result))
{
return result;
}

var updatedMemberExpression = (Expression)memberExpression.Update(innerExpression);
if (innerExpression != null
&& innerExpression.Type.IsNullableType()
Expand All @@ -164,24 +164,12 @@ static bool ShouldApplyNullProtectionForMemberAccess(Type callerType, string mem

private bool TryBindMember(Expression source, MemberIdentity memberIdentity, Type type, out Expression result)
{
source = source.UnwrapTypeConversion(out var convertedType);
result = null;
Type convertedType = null;
if (source is UnaryExpression unaryExpression
&& unaryExpression.NodeType == ExpressionType.Convert)
if (source is EntityShaperExpression entityShaperExpression)
{
source = unaryExpression.Operand;
if (unaryExpression.Type != typeof(object))
{
convertedType = unaryExpression.Type;
}
}

if (source is EntityProjectionExpression entityProjection)
{
var entityType = entityProjection.EntityType;
if (convertedType != null
&& !(convertedType.IsInterface
&& convertedType.IsAssignableFrom(entityType.ClrType)))
var entityType = entityShaperExpression.EntityType;
if (convertedType != null)
{
entityType = entityType.GetRootType().GetDerivedTypesInclusive()
.FirstOrDefault(et => et.ClrType == convertedType);
Expand All @@ -194,24 +182,25 @@ private bool TryBindMember(Expression source, MemberIdentity memberIdentity, Typ
var property = memberIdentity.MemberInfo != null
? entityType.FindProperty(memberIdentity.MemberInfo)
: entityType.FindProperty(memberIdentity.Name);
// If unmapped property return null
if (property == null)
if (property != null
&& Visit(entityShaperExpression.ValueBufferExpression) is EntityProjectionExpression entityProjectionExpression
&& (entityProjectionExpression.EntityType.IsAssignableFrom(property.DeclaringEntityType)
|| property.DeclaringEntityType.IsAssignableFrom(entityProjectionExpression.EntityType)))
{
return false;
}

result = BindProperty(entityProjection, property);
result = BindProperty(entityProjectionExpression, property);

// if the result type change was just nullability change e.g from int to int?
// we want to preserve the new type for null propagation
if (result.Type != type
&& !(result.Type.IsNullableType()
&& !type.IsNullableType()
&& result.Type.UnwrapNullableType() == type))
{
result = Expression.Convert(result, type);
}

// if the result type change was just nullability change e.g from int to int? we want to preserve the new type for null propagation
if (result.Type != type
&& !(result.Type.IsNullableType()
&& !type.IsNullableType()
&& result.Type.UnwrapNullableType() == type))
{
result = Expression.Convert(result, type);
return true;
}

return true;
}

return false;
Expand Down Expand Up @@ -284,7 +273,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
// EF.Property case
if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName))
{
if (TryBindMember(Visit(source), MemberIdentity.Create(propertyName), methodCallExpression.Type, out var result))
if (TryBindMember(source, MemberIdentity.Create(propertyName), methodCallExpression.Type, out var result))
{
return result;
}
Expand Down Expand Up @@ -396,16 +385,11 @@ MethodInfo GetMethod()
{
Expression result;
var innerExpression = ((NewArrayExpression)newValueBufferExpression.Arguments[0]).Expressions[0];
if (innerExpression is UnaryExpression unaryExpression
result = innerExpression is UnaryExpression unaryExpression
&& innerExpression.NodeType == ExpressionType.Convert
&& innerExpression.Type == typeof(object))
{
result = unaryExpression.Operand;
}
else
{
result = innerExpression;
}
&& innerExpression.Type == typeof(object)
? unaryExpression.Operand
: innerExpression;

return result.Type == methodCallExpression.Type
? result
Expand Down
Loading

0 comments on commit 9b0ceb2

Please sign in to comment.