Skip to content

Commit

Permalink
Query: Unwrap convert nodes around entity before doing member lookup
Browse files Browse the repository at this point in the history
Resolves #17794
  • Loading branch information
smitpatel committed Oct 30, 2019
1 parent 80cec2c commit b4103f5
Show file tree
Hide file tree
Showing 8 changed files with 287 additions and 275 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -95,38 +95,18 @@ protected override Expression VisitExtension(Expression node)
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override Expression VisitMember(MemberExpression memberExpression)
{
var innerExpression = Visit(memberExpression.Expression);

if (TryBindMember(innerExpression, MemberIdentity.Create(memberExpression.Member), out var result))
{
return result;
}

return TranslationFailed(memberExpression.Expression, innerExpression)
? null
: _memberTranslatorProvider.Translate((SqlExpression)innerExpression, memberExpression.Member, memberExpression.Type);
}
=> TryBindMember(memberExpression.Expression, MemberIdentity.Create(memberExpression.Member), out var result)
? result
: TranslationFailed(memberExpression.Expression, Visit(memberExpression.Expression), out var sqlInnerExpression)
? null
: _memberTranslatorProvider.Translate(sqlInnerExpression, memberExpression.Member, memberExpression.Type);

private bool TryBindMember(Expression source, MemberIdentity member, out Expression expression)
{
Type convertedType = null;
if (source is UnaryExpression unaryExpression
&& unaryExpression.NodeType == ExpressionType.Convert)
{
if (unaryExpression.Type != typeof(object))
{
convertedType = unaryExpression.Type;
}

source = unaryExpression.Operand;
}

source = Visit(source.UnwrapTypeConversion(out var convertedType));
if (source is EntityProjectionExpression entityProjectionExpression)
{
if (convertedType != null
&& convertedType.IsInterface
&& convertedType.IsAssignableFrom(entityProjectionExpression.Type))
if (convertedType == null)
{
convertedType = entityProjectionExpression.Type;
}
Expand Down Expand Up @@ -162,30 +142,29 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
{
if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName))
{
return TryBindMember(Visit(source), MemberIdentity.Create(propertyName), out var result)
return TryBindMember(source, MemberIdentity.Create(propertyName), out var result)
? result
: null;
}

var @object = Visit(methodCallExpression.Object);
if (TranslationFailed(methodCallExpression.Object, @object))
if (TranslationFailed(methodCallExpression.Object, Visit(methodCallExpression.Object), out var sqlObject))
{
return null;
}

var arguments = new SqlExpression[methodCallExpression.Arguments.Count];
for (var i = 0; i < arguments.Length; i++)
{
var argument = Visit(methodCallExpression.Arguments[i]);
if (TranslationFailed(methodCallExpression.Arguments[i], argument))
var argument = methodCallExpression.Arguments[i];
if (TranslationFailed(argument, Visit(argument), out var sqlArgument))
{
return null;
}

arguments[i] = (SqlExpression)argument;
arguments[i] = sqlArgument;
}

return _methodCallTranslatorProvider.Translate(_model, (SqlExpression)@object, methodCallExpression.Method, arguments);
return _methodCallTranslatorProvider.Translate(_model, sqlObject, methodCallExpression.Method, arguments);
}

private static Expression TryRemoveImplicitConvert(Expression expression)
Expand Down Expand Up @@ -240,17 +219,14 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
left = Visit(left);
right = Visit(right);

if (TranslationFailed(binaryExpression.Left, left)
|| TranslationFailed(binaryExpression.Right, right))
{
return null;
}

return _sqlExpressionFactory.MakeBinary(
binaryExpression.NodeType,
(SqlExpression)left,
(SqlExpression)right,
null);
return TranslationFailed(binaryExpression.Left, left, out var sqlLeft)
|| TranslationFailed(binaryExpression.Right, right, out var sqlRight)
? null
: _sqlExpressionFactory.MakeBinary(
binaryExpression.NodeType,
sqlLeft,
sqlRight,
null);
}

/// <summary>
Expand All @@ -265,14 +241,11 @@ protected override Expression VisitConditional(ConditionalExpression conditional
var ifTrue = Visit(conditionalExpression.IfTrue);
var ifFalse = Visit(conditionalExpression.IfFalse);

if (TranslationFailed(conditionalExpression.Test, test)
|| TranslationFailed(conditionalExpression.IfTrue, ifTrue)
|| TranslationFailed(conditionalExpression.IfFalse, ifFalse))
{
return null;
}

return _sqlExpressionFactory.Condition((SqlExpression)test, (SqlExpression)ifTrue, (SqlExpression)ifFalse);
return TranslationFailed(conditionalExpression.Test, test, out var sqlTest)
|| TranslationFailed(conditionalExpression.IfTrue, ifTrue, out var sqlIfTrue)
|| TranslationFailed(conditionalExpression.IfFalse, ifFalse, out var sqlIfFalse)
? null
: _sqlExpressionFactory.Condition(sqlTest, sqlIfTrue, sqlIfFalse);
}

/// <summary>
Expand All @@ -285,18 +258,11 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)
{
var operand = Visit(unaryExpression.Operand);

if (operand is EntityProjectionExpression)
{
return unaryExpression.Update(operand);
}

if (TranslationFailed(unaryExpression.Operand, operand))
if (TranslationFailed(unaryExpression.Operand, operand, out var sqlOperand))
{
return null;
}

var sqlOperand = (SqlExpression)operand;

switch (unaryExpression.NodeType)
{
case ExpressionType.Not:
Expand Down Expand Up @@ -334,7 +300,9 @@ private SqlConstantExpression GetConstantOrNull(Expression expression)

private static bool CanEvaluate(Expression expression)
{
#pragma warning disable IDE0066 // Convert switch statement to expression
switch (expression)
#pragma warning restore IDE0066 // Convert switch statement to expression
{
case ConstantExpression constantExpression:
return true;
Expand Down Expand Up @@ -450,7 +418,17 @@ protected override Expression VisitExtension(Expression extensionExpression)
}

[DebuggerStepThrough]
private bool TranslationFailed(Expression original, Expression translation)
=> original != null && !(translation is SqlExpression);
private bool TranslationFailed(Expression original, Expression translation, out SqlExpression castTranslation)
{
if (original != null
&& !(translation is SqlExpression))
{
castTranslation = null;
return true;
}

castTranslation = translation as SqlExpression;
return false;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -125,22 +125,22 @@ protected override Expression VisitConditional(ConditionalExpression conditional

protected override Expression VisitMember(MemberExpression memberExpression)
{
if (TryBindMember(
memberExpression.Expression,
MemberIdentity.Create(memberExpression.Member),
memberExpression.Type,
out var result))
{
return result;
}

var innerExpression = Visit(memberExpression.Expression);
if (memberExpression.Expression != null
&& innerExpression == null)
{
return null;
}

if ((innerExpression is EntityProjectionExpression
|| (innerExpression is UnaryExpression innerUnaryExpression
&& innerUnaryExpression.NodeType == ExpressionType.Convert
&& innerUnaryExpression.Operand is EntityProjectionExpression))
&& TryBindMember(innerExpression, MemberIdentity.Create(memberExpression.Member), memberExpression.Type, out var result))
{
return result;
}

var updatedMemberExpression = (Expression)memberExpression.Update(innerExpression);
if (innerExpression != null
&& innerExpression.Type.IsNullableType()
Expand All @@ -164,24 +164,12 @@ static bool ShouldApplyNullProtectionForMemberAccess(Type callerType, string mem

private bool TryBindMember(Expression source, MemberIdentity memberIdentity, Type type, out Expression result)
{
source = Visit(source.UnwrapTypeConversion(out var convertedType));
result = null;
Type convertedType = null;
if (source is UnaryExpression unaryExpression
&& unaryExpression.NodeType == ExpressionType.Convert)
{
source = unaryExpression.Operand;
if (unaryExpression.Type != typeof(object))
{
convertedType = unaryExpression.Type;
}
}

if (source is EntityProjectionExpression entityProjection)
{
var entityType = entityProjection.EntityType;
if (convertedType != null
&& !(convertedType.IsInterface
&& convertedType.IsAssignableFrom(entityType.ClrType)))
if (convertedType != null)
{
entityType = entityType.GetRootType().GetDerivedTypesInclusive()
.FirstOrDefault(et => et.ClrType == convertedType);
Expand Down Expand Up @@ -284,7 +272,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
// EF.Property case
if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName))
{
if (TryBindMember(Visit(source), MemberIdentity.Create(propertyName), methodCallExpression.Type, out var result))
if (TryBindMember(source, MemberIdentity.Create(propertyName), methodCallExpression.Type, out var result))
{
return result;
}
Expand Down Expand Up @@ -396,16 +384,11 @@ MethodInfo GetMethod()
{
Expression result;
var innerExpression = ((NewArrayExpression)newValueBufferExpression.Arguments[0]).Expressions[0];
if (innerExpression is UnaryExpression unaryExpression
result = innerExpression is UnaryExpression unaryExpression
&& innerExpression.NodeType == ExpressionType.Convert
&& innerExpression.Type == typeof(object))
{
result = unaryExpression.Operand;
}
else
{
result = innerExpression;
}
&& innerExpression.Type == typeof(object)
? unaryExpression.Operand
: innerExpression;

return result.Type == methodCallExpression.Type
? result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -828,39 +828,18 @@ protected override Expression VisitMember(MemberExpression memberExpression)
{
var innerExpression = Visit(memberExpression.Expression);

if (innerExpression is EntityShaperExpression
|| (innerExpression is UnaryExpression innerUnaryExpression
&& innerUnaryExpression.NodeType == ExpressionType.Convert
&& innerUnaryExpression.Operand is EntityShaperExpression))
{
var collectionNavigation = Expand(innerExpression, MemberIdentity.Create(memberExpression.Member));
if (collectionNavigation != null)
{
return collectionNavigation;
}
}

return memberExpression.Update(innerExpression);
return TryExpand(innerExpression, MemberIdentity.Create(memberExpression.Member))
?? memberExpression.Update(innerExpression);
}

protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var navigationName))
{
source = Visit(source);
if (source is EntityShaperExpression
|| (source is UnaryExpression innerUnaryExpression
&& innerUnaryExpression.NodeType == ExpressionType.Convert
&& innerUnaryExpression.Operand is EntityShaperExpression))
{
var collectionNavigation = Expand(source, MemberIdentity.Create(navigationName));
if (collectionNavigation != null)
{
return collectionNavigation;
}
}

return methodCallExpression.Update(null, new[] { source, methodCallExpression.Arguments[1] });
return TryExpand(source, MemberIdentity.Create(navigationName))
?? methodCallExpression.Update(null, new[] { source, methodCallExpression.Arguments[1] });
}

return base.VisitMethodCall(methodCallExpression);
Expand All @@ -871,19 +850,9 @@ protected override Expression VisitExtension(Expression extensionExpression)
? extensionExpression
: base.VisitExtension(extensionExpression);

private Expression Expand(Expression source, MemberIdentity member)
private Expression TryExpand(Expression source, MemberIdentity member)
{
Type convertedType = null;
if (source is UnaryExpression unaryExpression
&& unaryExpression.NodeType == ExpressionType.Convert)
{
source = unaryExpression.Operand;
if (unaryExpression.Type != typeof(object))
{
convertedType = unaryExpression.Type;
}
}

source = source.UnwrapTypeConversion(out var convertedType);
if (!(source is EntityShaperExpression entityShaperExpression))
{
return null;
Expand Down Expand Up @@ -1016,17 +985,7 @@ private ShapedQueryExpression TranslateScalarAggregate(
return null;
}

MethodInfo getMethod()
=> methodName switch
{
nameof(Enumerable.Average) => EnumerableMethods.GetAverageWithSelector(selector.ReturnType),
nameof(Enumerable.Max) => EnumerableMethods.GetMaxWithSelector(selector.ReturnType),
nameof(Enumerable.Min) => EnumerableMethods.GetMinWithSelector(selector.ReturnType),
nameof(Enumerable.Sum) => EnumerableMethods.GetSumWithSelector(selector.ReturnType),
_ => throw new InvalidOperationException("Invalid Aggregate Operator encountered."),
};

var method = getMethod();
var method = GetMethod();
method = method.GetGenericArguments().Length == 2
? method.MakeGenericMethod(typeof(ValueBuffer), selector.ReturnType)
: method.MakeGenericMethod(typeof(ValueBuffer));
Expand All @@ -1040,6 +999,16 @@ MethodInfo getMethod()
source.ShaperExpression = inMemoryQueryExpression.GetSingleScalarProjection();

return source;

MethodInfo GetMethod()
=> methodName switch
{
nameof(Enumerable.Average) => EnumerableMethods.GetAverageWithSelector(selector.ReturnType),
nameof(Enumerable.Max) => EnumerableMethods.GetMaxWithSelector(selector.ReturnType),
nameof(Enumerable.Min) => EnumerableMethods.GetMinWithSelector(selector.ReturnType),
nameof(Enumerable.Sum) => EnumerableMethods.GetSumWithSelector(selector.ReturnType),
_ => throw new InvalidOperationException("Invalid Aggregate Operator encountered."),
};
}

private ShapedQueryExpression TranslateSingleResultOperator(
Expand Down
Loading

0 comments on commit b4103f5

Please sign in to comment.