Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query: Unwrap convert nodes around entity before doing member lookup #18670

Merged
merged 1 commit into from
Nov 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -95,57 +95,44 @@ protected override Expression VisitExtension(Expression node)
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override Expression VisitMember(MemberExpression memberExpression)
{
var innerExpression = Visit(memberExpression.Expression);

if (TryBindMember(innerExpression, MemberIdentity.Create(memberExpression.Member), out var result))
{
return result;
}

return TranslationFailed(memberExpression.Expression, innerExpression)
? null
: _memberTranslatorProvider.Translate((SqlExpression)innerExpression, memberExpression.Member, memberExpression.Type);
}
=> TryBindMember(memberExpression.Expression, MemberIdentity.Create(memberExpression.Member), out var result)
? result
: TranslationFailed(memberExpression.Expression, Visit(memberExpression.Expression), out var sqlInnerExpression)
smitpatel marked this conversation as resolved.
Show resolved Hide resolved
? null
: _memberTranslatorProvider.Translate(sqlInnerExpression, memberExpression.Member, memberExpression.Type);

private bool TryBindMember(Expression source, MemberIdentity member, out Expression expression)
{
Type convertedType = null;
if (source is UnaryExpression unaryExpression
&& unaryExpression.NodeType == ExpressionType.Convert)
source = source.UnwrapTypeConversion(out var convertedType);
Expression visitedExpression;
switch (source)
{
if (unaryExpression.Type != typeof(object))
{
convertedType = unaryExpression.Type;
}
case EntityShaperExpression entityShaperExpression:
visitedExpression = Visit(entityShaperExpression.ValueBufferExpression);
break;

case MemberExpression memberExpression:
TryBindMember(memberExpression.Expression, MemberIdentity.Create(memberExpression.Member), out visitedExpression);
break;

case MethodCallExpression methodCallExpression
when methodCallExpression.TryGetEFPropertyArguments(out var innerSource, out var innerPropertyName):
TryBindMember(innerSource, MemberIdentity.Create(innerPropertyName), out visitedExpression);
break;

source = unaryExpression.Operand;
default:
visitedExpression = null;
break;
}

if (source is EntityProjectionExpression entityProjectionExpression)
if (visitedExpression is EntityProjectionExpression entityProjectionExpression)
{
if (convertedType != null
&& convertedType.IsInterface
&& convertedType.IsAssignableFrom(entityProjectionExpression.Type))
{
convertedType = entityProjectionExpression.Type;
}

convertedType ??= entityProjectionExpression.Type;
expression = member.MemberInfo != null
? entityProjectionExpression.BindMember(member.MemberInfo, convertedType, clientEval: false, out _)
: entityProjectionExpression.BindMember(member.Name, convertedType, clientEval: false, out _);
return expression != null;
}

if (source is MemberExpression innerMemberExpression
&& TryBindMember(innerMemberExpression, MemberIdentity.Create(innerMemberExpression.Member), out var innerResult))
{
if (convertedType != null)
{
innerResult = Expression.Convert(innerResult, convertedType);
}

return TryBindMember(innerResult, member, out expression);
return expression != null;
}

expression = null;
Expand All @@ -162,30 +149,29 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
{
if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName))
{
return TryBindMember(Visit(source), MemberIdentity.Create(propertyName), out var result)
return TryBindMember(source, MemberIdentity.Create(propertyName), out var result)
? result
: null;
}

var @object = Visit(methodCallExpression.Object);
if (TranslationFailed(methodCallExpression.Object, @object))
if (TranslationFailed(methodCallExpression.Object, Visit(methodCallExpression.Object), out var sqlObject))
{
return null;
}

var arguments = new SqlExpression[methodCallExpression.Arguments.Count];
for (var i = 0; i < arguments.Length; i++)
{
var argument = Visit(methodCallExpression.Arguments[i]);
if (TranslationFailed(methodCallExpression.Arguments[i], argument))
var argument = methodCallExpression.Arguments[i];
if (TranslationFailed(argument, Visit(argument), out var sqlArgument))
{
return null;
}

arguments[i] = (SqlExpression)argument;
arguments[i] = sqlArgument;
}

return _methodCallTranslatorProvider.Translate(_model, (SqlExpression)@object, methodCallExpression.Method, arguments);
return _methodCallTranslatorProvider.Translate(_model, sqlObject, methodCallExpression.Method, arguments);
}

private static Expression TryRemoveImplicitConvert(Expression expression)
Expand Down Expand Up @@ -240,17 +226,14 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
left = Visit(left);
right = Visit(right);

if (TranslationFailed(binaryExpression.Left, left)
|| TranslationFailed(binaryExpression.Right, right))
{
return null;
}

return _sqlExpressionFactory.MakeBinary(
binaryExpression.NodeType,
(SqlExpression)left,
(SqlExpression)right,
null);
return TranslationFailed(binaryExpression.Left, left, out var sqlLeft)
|| TranslationFailed(binaryExpression.Right, right, out var sqlRight)
? null
: _sqlExpressionFactory.MakeBinary(
binaryExpression.NodeType,
sqlLeft,
sqlRight,
null);
}

/// <summary>
Expand All @@ -265,14 +248,11 @@ protected override Expression VisitConditional(ConditionalExpression conditional
var ifTrue = Visit(conditionalExpression.IfTrue);
var ifFalse = Visit(conditionalExpression.IfFalse);

if (TranslationFailed(conditionalExpression.Test, test)
|| TranslationFailed(conditionalExpression.IfTrue, ifTrue)
|| TranslationFailed(conditionalExpression.IfFalse, ifFalse))
{
return null;
}

return _sqlExpressionFactory.Condition((SqlExpression)test, (SqlExpression)ifTrue, (SqlExpression)ifFalse);
return TranslationFailed(conditionalExpression.Test, test, out var sqlTest)
|| TranslationFailed(conditionalExpression.IfTrue, ifTrue, out var sqlIfTrue)
|| TranslationFailed(conditionalExpression.IfFalse, ifFalse, out var sqlIfFalse)
? null
: _sqlExpressionFactory.Condition(sqlTest, sqlIfTrue, sqlIfFalse);
}

/// <summary>
Expand All @@ -285,18 +265,11 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)
{
var operand = Visit(unaryExpression.Operand);

if (operand is EntityProjectionExpression)
{
return unaryExpression.Update(operand);
}

if (TranslationFailed(unaryExpression.Operand, operand))
if (TranslationFailed(unaryExpression.Operand, operand, out var sqlOperand))
{
return null;
}

var sqlOperand = (SqlExpression)operand;

switch (unaryExpression.NodeType)
{
case ExpressionType.Not:
Expand Down Expand Up @@ -334,7 +307,9 @@ private SqlConstantExpression GetConstantOrNull(Expression expression)

private static bool CanEvaluate(Expression expression)
{
#pragma warning disable IDE0066 // Convert switch statement to expression
switch (expression)
#pragma warning restore IDE0066 // Convert switch statement to expression
{
case ConstantExpression constantExpression:
return true;
Expand Down Expand Up @@ -447,7 +422,17 @@ protected override Expression VisitExtension(Expression extensionExpression)
}

[DebuggerStepThrough]
private bool TranslationFailed(Expression original, Expression translation)
=> original != null && !(translation is SqlExpression);
private bool TranslationFailed(Expression original, Expression translation, out SqlExpression castTranslation)
{
if (original != null
&& !(translation is SqlExpression))
{
castTranslation = null;
return true;
}

castTranslation = translation as SqlExpression;
return false;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -125,22 +125,22 @@ protected override Expression VisitConditional(ConditionalExpression conditional

protected override Expression VisitMember(MemberExpression memberExpression)
{
if (TryBindMember(
memberExpression.Expression,
MemberIdentity.Create(memberExpression.Member),
memberExpression.Type,
out var result))
{
return result;
}

var innerExpression = Visit(memberExpression.Expression);
smitpatel marked this conversation as resolved.
Show resolved Hide resolved
if (memberExpression.Expression != null
&& innerExpression == null)
{
return null;
}

if ((innerExpression is EntityProjectionExpression
|| (innerExpression is UnaryExpression innerUnaryExpression
&& innerUnaryExpression.NodeType == ExpressionType.Convert
&& innerUnaryExpression.Operand is EntityProjectionExpression))
&& TryBindMember(innerExpression, MemberIdentity.Create(memberExpression.Member), memberExpression.Type, out var result))
{
return result;
}

var updatedMemberExpression = (Expression)memberExpression.Update(innerExpression);
if (innerExpression != null
&& innerExpression.Type.IsNullableType()
Expand All @@ -164,24 +164,12 @@ static bool ShouldApplyNullProtectionForMemberAccess(Type callerType, string mem

private bool TryBindMember(Expression source, MemberIdentity memberIdentity, Type type, out Expression result)
{
source = source.UnwrapTypeConversion(out var convertedType);
result = null;
Type convertedType = null;
if (source is UnaryExpression unaryExpression
&& unaryExpression.NodeType == ExpressionType.Convert)
if (source is EntityShaperExpression entityShaperExpression)
{
source = unaryExpression.Operand;
if (unaryExpression.Type != typeof(object))
{
convertedType = unaryExpression.Type;
}
}

if (source is EntityProjectionExpression entityProjection)
{
var entityType = entityProjection.EntityType;
if (convertedType != null
&& !(convertedType.IsInterface
&& convertedType.IsAssignableFrom(entityType.ClrType)))
var entityType = entityShaperExpression.EntityType;
if (convertedType != null)
{
entityType = entityType.GetRootType().GetDerivedTypesInclusive()
.FirstOrDefault(et => et.ClrType == convertedType);
Expand All @@ -194,24 +182,25 @@ private bool TryBindMember(Expression source, MemberIdentity memberIdentity, Typ
var property = memberIdentity.MemberInfo != null
? entityType.FindProperty(memberIdentity.MemberInfo)
: entityType.FindProperty(memberIdentity.Name);
// If unmapped property return null
if (property == null)
if (property != null
&& Visit(entityShaperExpression.ValueBufferExpression) is EntityProjectionExpression entityProjectionExpression
&& (entityProjectionExpression.EntityType.IsAssignableFrom(property.DeclaringEntityType)
|| property.DeclaringEntityType.IsAssignableFrom(entityProjectionExpression.EntityType)))
{
return false;
}

result = BindProperty(entityProjection, property);
result = BindProperty(entityProjectionExpression, property);

// if the result type change was just nullability change e.g from int to int?
// we want to preserve the new type for null propagation
if (result.Type != type
&& !(result.Type.IsNullableType()
&& !type.IsNullableType()
&& result.Type.UnwrapNullableType() == type))
{
result = Expression.Convert(result, type);
}

// if the result type change was just nullability change e.g from int to int? we want to preserve the new type for null propagation
if (result.Type != type
&& !(result.Type.IsNullableType()
&& !type.IsNullableType()
&& result.Type.UnwrapNullableType() == type))
{
result = Expression.Convert(result, type);
return true;
}

return true;
}

return false;
Expand Down Expand Up @@ -284,7 +273,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
// EF.Property case
if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName))
{
if (TryBindMember(Visit(source), MemberIdentity.Create(propertyName), methodCallExpression.Type, out var result))
if (TryBindMember(source, MemberIdentity.Create(propertyName), methodCallExpression.Type, out var result))
{
return result;
}
Expand Down Expand Up @@ -396,16 +385,11 @@ MethodInfo GetMethod()
{
Expression result;
var innerExpression = ((NewArrayExpression)newValueBufferExpression.Arguments[0]).Expressions[0];
if (innerExpression is UnaryExpression unaryExpression
result = innerExpression is UnaryExpression unaryExpression
&& innerExpression.NodeType == ExpressionType.Convert
&& innerExpression.Type == typeof(object))
{
result = unaryExpression.Operand;
}
else
{
result = innerExpression;
}
&& innerExpression.Type == typeof(object)
? unaryExpression.Operand
: innerExpression;

return result.Type == methodCallExpression.Type
? result
Expand Down
Loading