From c3fab83946f9c9b0c901c1b442f5e3aae5f6d8f2 Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Wed, 28 Aug 2019 19:03:06 -0700 Subject: [PATCH] InMemory: Some refactorings Part of #16963 --- ...yExpressionTranslatingExpressionVisitor.cs | 105 ++++++++---------- .../Internal/InMemoryLinqOperatorProvider.cs | 15 ++- .../Query/Internal/InMemoryQueryExpression.cs | 6 +- ...yableMethodTranslatingExpressionVisitor.cs | 82 ++++++-------- 4 files changed, 98 insertions(+), 110 deletions(-) diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs index cc54fa1e294..714b4d01dcf 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs @@ -78,11 +78,11 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) return null; } - if (TypeNullabilityChanged(newLeft.Type, binaryExpression.Left.Type) - || TypeNullabilityChanged(newRight.Type, binaryExpression.Right.Type)) + if (IsConvertedToNullable(newLeft, binaryExpression.Left) + || IsConvertedToNullable(newRight, binaryExpression.Right)) { - newLeft = MakeNullable(newLeft); - newRight = MakeNullable(newRight); + newLeft = ConvertToNullable(newLeft); + newRight = ConvertToNullable(newRight); } return Expression.MakeBinary( @@ -94,11 +94,6 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) binaryExpression.Conversion); } - private static Expression MakeNullable(Expression expression) - => !expression.Type.IsNullableType() - ? Expression.Convert(expression, expression.Type.MakeNullable()) - : expression; - protected override Expression VisitConditional(ConditionalExpression conditionalExpression) { var test = Visit(conditionalExpression.Test); @@ -115,11 +110,11 @@ protected override Expression VisitConditional(ConditionalExpression conditional test = Expression.Equal(test, Expression.Constant(true, typeof(bool?))); } - if (TypeNullabilityChanged(ifTrue.Type, conditionalExpression.IfTrue.Type) - || TypeNullabilityChanged(ifFalse.Type, conditionalExpression.IfFalse.Type)) + if (IsConvertedToNullable(ifTrue, conditionalExpression.IfTrue) + || IsConvertedToNullable(ifFalse, conditionalExpression.IfFalse)) { - ifTrue = MakeNullable(ifTrue); - ifFalse = MakeNullable(ifFalse); + ifTrue = ConvertToNullable(ifTrue); + ifFalse = ConvertToNullable(ifFalse); } return Expression.Condition(test, ifTrue, ifFalse); @@ -142,12 +137,17 @@ protected override Expression VisitMember(MemberExpression memberExpression) return result; } + static bool shouldApplyNullProtectionForMemberAccess(Type callerType, string memberName) + => !(callerType.IsGenericType + && callerType.GetGenericTypeDefinition() == typeof(Nullable<>) + && (memberName == nameof(Nullable.Value) || memberName == nameof(Nullable.HasValue))); + var updatedMemberExpression = (Expression)memberExpression.Update(innerExpression); if (innerExpression != null && innerExpression.Type.IsNullableType() - && ShouldApplyNullProtectionForMemberAccess(innerExpression.Type, memberExpression.Member.Name)) + && shouldApplyNullProtectionForMemberAccess(innerExpression.Type, memberExpression.Member.Name)) { - updatedMemberExpression = MakeNullable(updatedMemberExpression); + updatedMemberExpression = ConvertToNullable(updatedMemberExpression); return Expression.Condition( Expression.Equal(innerExpression, Expression.Default(innerExpression.Type)), @@ -158,11 +158,6 @@ protected override Expression VisitMember(MemberExpression memberExpression) return updatedMemberExpression; } - private bool ShouldApplyNullProtectionForMemberAccess(Type callerType, string memberName) - => !(callerType.IsGenericType - && callerType.GetGenericTypeDefinition() == typeof(Nullable<>) - && (memberName == nameof(Nullable.Value) || memberName == nameof(Nullable.HasValue))); - private bool TryBindMember(Expression source, MemberIdentity memberIdentity, Type type, out Expression result) { result = null; @@ -203,7 +198,10 @@ private bool TryBindMember(Expression source, MemberIdentity memberIdentity, Typ result = BindProperty(entityProjection, property); // if the result type change was just nullability change e.g from int to int? we want to preserve the new type for null propagation - if (result.Type != type && !TypeNullabilityChanged(result.Type, type)) + if (result.Type != type + && !(result.Type.IsNullableType() + && !type.IsNullableType() + && result.Type.UnwrapNullableType() == type)) { result = Expression.Convert(result, type); } @@ -214,13 +212,23 @@ private bool TryBindMember(Expression source, MemberIdentity memberIdentity, Typ return false; } - private bool TypeNullabilityChanged(Type maybeNullableType, Type nonNullableType) - => maybeNullableType.IsNullableType() && !nonNullableType.IsNullableType() && maybeNullableType.UnwrapNullableType() == nonNullableType; + private static bool IsConvertedToNullable(Expression result, Expression original) + => result.Type.IsNullableType() + && !original.Type.IsNullableType() + && result.Type.UnwrapNullableType() == original.Type; + + private static Expression ConvertToNullable(Expression expression) + => !expression.Type.IsNullableType() + ? Expression.Convert(expression, expression.Type.MakeNullable()) + : expression; + + private static Expression ConvertToNonNullable(Expression expression) + => expression.Type.IsNullableType() + ? Expression.Convert(expression, expression.Type.UnwrapNullableType()) + : expression; private static Expression BindProperty(EntityProjectionExpression entityProjectionExpression, IProperty property) - { - return entityProjectionExpression.BindProperty(property); - } + => entityProjectionExpression.BindProperty(property); private static Expression GetSelector(MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression) { @@ -386,7 +394,7 @@ MethodInfo getMethod() } var arguments = new Expression[methodCallExpression.Arguments.Count]; - var parameterTypes = methodCallExpression.Method.GetParameters().Select(p => p.ParameterType).ToArray(); + var parameterTypes = methodCallExpression.Method.GetParameters().Select(p => p.ParameterType).ToArray(); for (var i = 0; i < arguments.Length; i++) { var argument = Visit(methodCallExpression.Arguments[i]); @@ -397,10 +405,10 @@ MethodInfo getMethod() // if the nullability of arguments change, we have no easy/reliable way to adjust the actual methodInfo to match the new type, // so we are forced to cast back to the original type - if (argument.Type != methodCallExpression.Arguments[i].Type + if (IsConvertedToNullable(argument, methodCallExpression.Arguments[i]) && !parameterTypes[i].IsAssignableFrom(argument.Type)) { - argument = Expression.Convert(argument, methodCallExpression.Arguments[i].Type); + argument = ConvertToNonNullable(argument); } arguments[i] = argument; @@ -412,11 +420,11 @@ MethodInfo getMethod() && @object.Type.IsNullableType() && !(methodCallExpression.Method.Name == nameof(Nullable.GetValueOrDefault))) { - var result = (Expression)methodCallExpression.Update( + var result = (Expression)methodCallExpression.Update( Expression.Convert(@object, methodCallExpression.Object.Type), arguments); - result = MakeNullable(result); + result = ConvertToNullable(result); result = Expression.Condition( Expression.Equal(@object, Expression.Constant(null, @object.Type)), Expression.Constant(null, result.Type), @@ -474,9 +482,9 @@ protected override Expression VisitNew(NewExpression newExpression) foreach (var argument in newExpression.Arguments) { var newArgument = Visit(argument); - if (newArgument.Type != argument.Type) + if (IsConvertedToNullable(newArgument, argument)) { - newArgument = Expression.Convert(newArgument, argument.Type); + newArgument = ConvertToNonNullable(newArgument); } newArguments.Add(newArgument); @@ -491,9 +499,9 @@ protected override Expression VisitNewArray(NewArrayExpression newArrayExpressio foreach (var expression in newArrayExpression.Expressions) { var newExpression = Visit(expression); - if (newExpression.Type != expression.Type) + if (IsConvertedToNullable(newExpression, expression)) { - newExpression = Expression.Convert(newExpression, expression.Type); + newExpression = ConvertToNonNullable(newExpression); } newExpressions.Add(newExpression); @@ -502,32 +510,15 @@ protected override Expression VisitNewArray(NewArrayExpression newArrayExpressio return newArrayExpression.Update(newExpressions); } - protected override Expression VisitMemberInit(MemberInitExpression memberInitExpression) + protected override MemberAssignment VisitMemberAssignment(MemberAssignment memberAssignment) { - var newExpression = (NewExpression)Visit(memberInitExpression.NewExpression); - var bindings = new List(); - foreach (var binding in memberInitExpression.Bindings) + var expression = Visit(memberAssignment.Expression); + if (IsConvertedToNullable(expression, memberAssignment.Expression)) { - switch (binding) - { - case MemberAssignment memberAssignment: - var expression = Visit(memberAssignment.Expression); - if (expression.Type != memberAssignment.Expression.Type) - { - expression = Expression.Convert(expression, memberAssignment.Expression.Type); - } - - bindings.Add(Expression.Bind(memberAssignment.Member, expression)); - break; - - default: - // TODO: MemberMemberBinding and MemberListBinding - bindings.Add(binding); - break; - } + expression = ConvertToNonNullable(expression); } - return memberInitExpression.Update(newExpression, bindings); + return memberAssignment.Update(expression); } protected override Expression VisitExtension(Expression extensionExpression) diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryLinqOperatorProvider.cs b/src/EFCore.InMemory/Query/Internal/InMemoryLinqOperatorProvider.cs index 3d7f53aee47..34b561f941a 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryLinqOperatorProvider.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryLinqOperatorProvider.cs @@ -108,9 +108,20 @@ public static MethodInfo GetMinWithSelector(Type type) private static Dictionary SumWithoutSelectorMethods { get; } private static Dictionary SumWithSelectorMethods { get; } - private static bool IsFunc(Type type, int funcGenericArgs = 2) + private static Type GetFuncType(int funcGenericArguments) + { + return funcGenericArguments switch + { + 1 => typeof(Func<>), + 2 => typeof(Func<,>), + 3 => typeof(Func<,,>), + 4 => typeof(Func<,,,>), + _ => throw new InvalidOperationException("Invalid number of arguments for Func"), + }; + } + private static bool IsFunc(Type type, int funcGenericArguments = 2) => type.IsGenericType - && type.GetGenericArguments().Length == funcGenericArgs; + && type.GetGenericTypeDefinition() == GetFuncType(funcGenericArguments); static InMemoryLinqOperatorProvider() { diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs b/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs index a198ddf478e..d6b8c8fb669 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs @@ -51,8 +51,8 @@ public InMemoryQueryExpression(IEntityType entityType) Constant(property.GetIndex()), MakeMemberAccess(_valueBufferParameter, _valueBufferCountMemberInfo)), - CreateReadValueExpression(typeof(object), property.GetIndex(), property), - Default(typeof(object))); + CreateReadValueExpression(property.ClrType, property.GetIndex(), property), + Default(property.ClrType)); } var entityProjection = new EntityProjectionExpression(entityType, readExpressionMap); @@ -256,7 +256,7 @@ public virtual void ApplyDefaultIfEmpty() { if (_valueBufferSlots.Count != 0) { - throw new InvalidOperationException("Cannot Apply DefaultIfEmpty after ClientProjection."); + throw new InvalidOperationException("Cannot apply DefaultIfEmpty after a client-evaluated projection."); } var result = new Dictionary(); diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs index d1d8fad4e53..39e405bddba 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs @@ -6,8 +6,6 @@ using System.Linq; using System.Linq.Expressions; using System.Reflection; -using Microsoft.EntityFrameworkCore.Diagnostics; -using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Metadata.Internal; using Microsoft.EntityFrameworkCore.Query; @@ -131,23 +129,17 @@ protected override ShapedQueryExpression TranslateConcat(ShapedQueryExpression s protected override ShapedQueryExpression TranslateContains(ShapedQueryExpression source, Expression item) { var inMemoryQueryExpression = (InMemoryQueryExpression)source.QueryExpression; - var itemType = item.Type; - item = TranslateExpression(item); + item = TranslateExpression(item, preserveType: true); if (item == null) { return null; } - if (item.Type != itemType) - { - item = Expression.Convert(item, itemType); - } - inMemoryQueryExpression.ServerQueryExpression = Expression.Call( - InMemoryLinqOperatorProvider.Contains.MakeGenericMethod(itemType), + InMemoryLinqOperatorProvider.Contains.MakeGenericMethod(item.Type), Expression.Call( - InMemoryLinqOperatorProvider.Select.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type, itemType), + InMemoryLinqOperatorProvider.Select.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type, item.Type), inMemoryQueryExpression.ServerQueryExpression, Expression.Lambda( inMemoryQueryExpression.GetMappedProjection(new ProjectionMember()), inMemoryQueryExpression.CurrentParameter)), @@ -349,9 +341,7 @@ protected override ShapedQueryExpression TranslateJoin(ShapedQueryExpression out return null; } - var unifyNullabilityResult = UnifyNullability(outerKeySelector, innerKeySelector); - outerKeySelector = unifyNullabilityResult.lambda1; - innerKeySelector = unifyNullabilityResult.lambda2; + (outerKeySelector, innerKeySelector) = AlignKeySelectorTypes(outerKeySelector, innerKeySelector); var transparentIdentifierType = TransparentIdentifierFactory.Create( resultSelector.Parameters[0].Type, @@ -370,27 +360,31 @@ protected override ShapedQueryExpression TranslateJoin(ShapedQueryExpression out transparentIdentifierType); } - private (LambdaExpression lambda1, LambdaExpression lambda2) UnifyNullability(LambdaExpression lambda1, LambdaExpression lambda2) + private (LambdaExpression OuterKeySelector, LambdaExpression InnerKeySelector) + AlignKeySelectorTypes(LambdaExpression outerKeySelector, LambdaExpression innerKeySelector) { - if (lambda1.Body.Type != lambda2.Body.Type) + static bool isConvertedToNullable(Expression outer, Expression inner) + => outer.Type.IsNullableType() + && !inner.Type.IsNullableType() + && outer.Type.UnwrapNullableType() == inner.Type; + + if (outerKeySelector.Body.Type != innerKeySelector.Body.Type) { - if (TypeNullabilityChanged(lambda1.Body.Type, lambda2.Body.Type)) + if (isConvertedToNullable(outerKeySelector.Body, innerKeySelector.Body)) { - lambda2 = Expression.Lambda(Expression.Convert(lambda2.Body, lambda1.Body.Type), lambda2.Parameters); + innerKeySelector = Expression.Lambda( + Expression.Convert(innerKeySelector.Body, outerKeySelector.Body.Type), innerKeySelector.Parameters); } - else if (TypeNullabilityChanged(lambda2.Body.Type, lambda1.Body.Type)) + else if (isConvertedToNullable(innerKeySelector.Body, outerKeySelector.Body)) { - lambda1 = Expression.Lambda(Expression.Convert(lambda1.Body, lambda2.Body.Type), lambda1.Parameters); + outerKeySelector = Expression.Lambda( + Expression.Convert(outerKeySelector.Body, innerKeySelector.Body.Type), outerKeySelector.Parameters); } } - return (lambda1, lambda2); + return (outerKeySelector, innerKeySelector); } - // TODO: DRY - private bool TypeNullabilityChanged(Type maybeNullableType, Type nonNullableType) - => maybeNullableType.IsNullableType() && !nonNullableType.IsNullableType() && maybeNullableType.UnwrapNullableType() == nonNullableType; - protected override ShapedQueryExpression TranslateLastOrDefault(ShapedQueryExpression source, LambdaExpression predicate, Type returnType, bool returnDefault) { return TranslateSingleResultOperator( @@ -411,9 +405,7 @@ protected override ShapedQueryExpression TranslateLeftJoin(ShapedQueryExpression return null; } - var unifyNullabilityResult = UnifyNullability(outerKeySelector, innerKeySelector); - outerKeySelector = unifyNullabilityResult.lambda1; - innerKeySelector = unifyNullabilityResult.lambda2; + (outerKeySelector, innerKeySelector) = AlignKeySelectorTypes(outerKeySelector, innerKeySelector); var transparentIdentifierType = TransparentIdentifierFactory.Create( resultSelector.Parameters[0].Type, @@ -752,9 +744,19 @@ protected override ShapedQueryExpression TranslateWhere(ShapedQueryExpression so return source; } - private Expression TranslateExpression(Expression expression) + private Expression TranslateExpression(Expression expression, bool preserveType = false) { - return _expressionTranslator.Translate(expression); + var result = _expressionTranslator.Translate(expression); + + if (expression != null && result != null + && preserveType && expression.Type != result.Type) + { + result = expression.Type == typeof(bool) + ? Expression.Equal(result, Expression.Constant(true, result.Type)) + : (Expression)Expression.Convert(result, expression.Type); + } + + return result; } private LambdaExpression TranslateLambdaExpression( @@ -762,16 +764,7 @@ private LambdaExpression TranslateLambdaExpression( LambdaExpression lambdaExpression, bool preserveType = false) { - var lambdaBody = TranslateExpression(RemapLambdaBody(shapedQueryExpression, lambdaExpression)); - - if (lambdaBody != null && preserveType) - { - lambdaBody = lambdaBody.Type == typeof(bool?) - ? Expression.Equal( - lambdaBody, - Expression.Constant(true, typeof(bool?))) - : lambdaBody; - } + var lambdaBody = TranslateExpression(RemapLambdaBody(shapedQueryExpression, lambdaExpression), preserveType); return lambdaBody != null ? Expression.Lambda(lambdaBody, @@ -790,25 +783,18 @@ private ShapedQueryExpression TranslateScalarAggregate( { var inMemoryQueryExpression = (InMemoryQueryExpression)source.QueryExpression; - var selectorBodyType = selector?.Body.Type; - selector = selector == null || selector.Body == selector.Parameters[0] ? Expression.Lambda( inMemoryQueryExpression.GetMappedProjection(new ProjectionMember()), inMemoryQueryExpression.CurrentParameter) - : TranslateLambdaExpression(source, selector); + : TranslateLambdaExpression(source, selector, preserveType: true); if (selector == null) { return null; } - if (selectorBodyType != null && selector.Body.Type != selectorBodyType) - { - selector = Expression.Lambda(Expression.Convert(selector.Body, selectorBodyType), selector.Parameters); - } - MethodInfo getMethod() => methodName switch {