From ec03e289fcf81f037239b87efafd6f546da7ae1e Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Wed, 28 Aug 2019 15:54:02 -0700 Subject: [PATCH 1/3] InMemory: Support for DefaultIfEmpty Part of #16963 --- .../Query/Internal/InMemoryQueryExpression.cs | 56 ++++++++++++++++++- ...yableMethodTranslatingExpressionVisitor.cs | 12 +++- src/EFCore/Query/EntityMaterializerSource.cs | 2 +- .../Query/AsyncSimpleQueryInMemoryTest.cs | 36 ------------ .../Query/SimpleQueryInMemoryTest.cs | 55 +++--------------- 5 files changed, 76 insertions(+), 85 deletions(-) diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs b/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs index 542d58dcc93..5de9275fde9 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs @@ -252,7 +252,61 @@ public virtual void PushdownIntoSubquery() } } - private IPropertyBase InferPropertyFromInner(Expression expression) + public virtual void ApplyDefaultIfEmpty() + { + if (_valueBufferSlots.Count != 0) + { + throw new InvalidOperationException("Cannot Apply DefaultIfEmpty after ClientProjection."); + } + + var result = new Dictionary(); + foreach (var keyValuePair in _projectionMapping) + { + if (keyValuePair.Value is EntityProjectionExpression entityProjection) + { + var map = new Dictionary(); + foreach (var property in GetAllPropertiesInHierarchy(entityProjection.EntityType)) + { + var index = AddToProjection(entityProjection.BindProperty(property)); + map[property] = CreateReadValueExpression(property.ClrType.MakeNullable(), index, property); + } + result[keyValuePair.Key] = new EntityProjectionExpression(entityProjection.EntityType, map); + } + else + { + var index = AddToProjection(keyValuePair.Value); + result[keyValuePair.Key] = CreateReadValueExpression( + keyValuePair.Value.Type.MakeNullable(), index, InferPropertyFromInner(keyValuePair.Value)); + } + } + + _projectionMapping = result; + + var selectorLambda = Lambda( + New( + _valueBufferConstructor, + NewArrayInit( + typeof(object), + _valueBufferSlots + .Select(e => e.Type.IsValueType ? Convert(e, typeof(object)) : e))), + CurrentParameter); + + _groupingParameter = null; + + ServerQueryExpression = Call( + InMemoryLinqOperatorProvider.Select.MakeGenericMethod(ServerQueryExpression.Type.TryGetSequenceType(), typeof(ValueBuffer)), + ServerQueryExpression, + selectorLambda); + + ServerQueryExpression = Call( + InMemoryLinqOperatorProvider.DefaultIfEmptyWithArgument.MakeGenericMethod(typeof(ValueBuffer)), + ServerQueryExpression, + New(_valueBufferConstructor, NewArrayInit(typeof(object), Enumerable.Repeat(Constant(null), _valueBufferSlots.Count)))); + + _valueBufferSlots.Clear(); + } + + private static IPropertyBase InferPropertyFromInner(Expression expression) { if (expression is MethodCallExpression methodCallExpression && methodCallExpression.Method.IsGenericMethod diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs index 7329a3e2b16..1944ed1f7ff 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs @@ -184,7 +184,17 @@ protected override ShapedQueryExpression TranslateCount(ShapedQueryExpression so } protected override ShapedQueryExpression TranslateDefaultIfEmpty(ShapedQueryExpression source, Expression defaultValue) - => null; + { + if (defaultValue == null) + { + ((InMemoryQueryExpression)source.QueryExpression).ApplyDefaultIfEmpty(); + source.ShaperExpression = MarkShaperNullable(source.ShaperExpression); + + return source; + } + + return null; + } protected override ShapedQueryExpression TranslateDistinct(ShapedQueryExpression source) { diff --git a/src/EFCore/Query/EntityMaterializerSource.cs b/src/EFCore/Query/EntityMaterializerSource.cs index 76b5b77c263..d4cbb842cac 100644 --- a/src/EFCore/Query/EntityMaterializerSource.cs +++ b/src/EFCore/Query/EntityMaterializerSource.cs @@ -52,7 +52,7 @@ public static readonly MethodInfo TryReadValueMethod [MethodImpl(MethodImplOptions.AggressiveInlining)] private static TValue TryReadValue( in ValueBuffer valueBuffer, int index, IPropertyBase property) - => (TValue)valueBuffer[index]; + => valueBuffer[index] is TValue value ? value : default; public virtual Expression CreateMaterializeExpression( IEntityType entityType, diff --git a/test/EFCore.InMemory.FunctionalTests/Query/AsyncSimpleQueryInMemoryTest.cs b/test/EFCore.InMemory.FunctionalTests/Query/AsyncSimpleQueryInMemoryTest.cs index 386bec964c3..295108eb638 100644 --- a/test/EFCore.InMemory.FunctionalTests/Query/AsyncSimpleQueryInMemoryTest.cs +++ b/test/EFCore.InMemory.FunctionalTests/Query/AsyncSimpleQueryInMemoryTest.cs @@ -13,41 +13,5 @@ public AsyncSimpleQueryInMemoryTest(NorthwindQueryInMemoryFixture Date: Wed, 21 Aug 2019 15:39:50 -0700 Subject: [PATCH 2/3] Add null propagation/protection logic for InMemory provider. When we bind to a non-nullable property on entity that can be nullable (e.g. due to left join) we modify the result to be nullable, to avoid "nullable object must have a value" errors. This nullability is then propagated further. We have few blockers: - predicates (Where, Any, First, Count, etc): always need to be of type bool. When necessary we add "== true". - conditional expression Test: needs to be bool, same as above - method call arguments: we can't reliably rewrite methodcall when the arguments types change (generic methods specifically), we convert arguments back to their original types if they were changed to nullable versions. - method call caller: if the caller was changed from non-nullable to nullable we still need to call the method with the original type, but we add null check before - caller.Method(args) -> nullable_caller == null ? null : (resultType?)caller.Method(args) - selectors (Select, Max etc): we need to preserve the original result type, we use convert - anonymous type, array init: we need to preserve the original type, we use convert Also enable GearsOfWar and ComplexNavigation tests for in memory. --- .../Internal/EntityProjectionExpression.cs | 17 +- ...yExpressionTranslatingExpressionVisitor.cs | 171 +++++++++++++++++- ...emoryProjectionBindingExpressionVisitor.cs | 19 +- .../Query/Internal/InMemoryQueryExpression.cs | 27 ++- ...yableMethodTranslatingExpressionVisitor.cs | 75 ++++++-- .../Query/Internal/InMemoryTableExpression.cs | 8 +- ...SubqueryMemberPushdownExpressionVisitor.cs | 1 + .../ComplexNavigationsQueryInMemoryTest.cs | 170 ++++++++++++++++- .../Query/GearsOfWarQueryInMemoryTest.cs | 104 ++++++++++- .../QueryFilterFuncletizationInMemoryTest.cs | 7 + .../Query/SimpleQueryInMemoryTest.cs | 9 + .../Query/SpatialQueryInMemoryTest.cs | 123 ------------- 12 files changed, 581 insertions(+), 150 deletions(-) diff --git a/src/EFCore.InMemory/Query/Internal/EntityProjectionExpression.cs b/src/EFCore.InMemory/Query/Internal/EntityProjectionExpression.cs index 2f34586df3f..265765b4152 100644 --- a/src/EFCore.InMemory/Query/Internal/EntityProjectionExpression.cs +++ b/src/EFCore.InMemory/Query/Internal/EntityProjectionExpression.cs @@ -5,10 +5,11 @@ using System.Collections.Generic; using System.Linq.Expressions; using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Query; namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal { - public class EntityProjectionExpression : Expression + public class EntityProjectionExpression : Expression, IPrintableExpression { private readonly IDictionary _readExpressionMap; @@ -50,5 +51,19 @@ public virtual Expression BindProperty(IProperty property) return _readExpressionMap[property]; } + + public virtual void Print(ExpressionPrinter expressionPrinter) + { + expressionPrinter.AppendLine(nameof(EntityProjectionExpression) + ":"); + using (expressionPrinter.Indent()) + { + foreach (var readExpressionMapEntry in _readExpressionMap) + { + expressionPrinter.Append(readExpressionMapEntry.Key + " -> "); + expressionPrinter.Visit(readExpressionMapEntry.Value); + expressionPrinter.AppendLine(); + } + } + } } } diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs index 0505562e8dd..cc54fa1e294 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.Diagnostics; using System.Linq; using System.Linq.Expressions; @@ -69,27 +70,59 @@ public virtual Expression Translate(Expression expression) protected override Expression VisitBinary(BinaryExpression binaryExpression) { - var left = Visit(binaryExpression.Left); - var right = Visit(binaryExpression.Right); - if (left == null || right == null) + var newLeft = Visit(binaryExpression.Left); + var newRight = Visit(binaryExpression.Right); + + if (newLeft == null || newRight == null) { return null; } - return binaryExpression.Update(left, binaryExpression.Conversion, right); + if (TypeNullabilityChanged(newLeft.Type, binaryExpression.Left.Type) + || TypeNullabilityChanged(newRight.Type, binaryExpression.Right.Type)) + { + newLeft = MakeNullable(newLeft); + newRight = MakeNullable(newRight); + } + + return Expression.MakeBinary( + binaryExpression.NodeType, + newLeft, + newRight, + binaryExpression.IsLiftedToNull, + binaryExpression.Method, + binaryExpression.Conversion); } + private static Expression MakeNullable(Expression expression) + => !expression.Type.IsNullableType() + ? Expression.Convert(expression, expression.Type.MakeNullable()) + : expression; + protected override Expression VisitConditional(ConditionalExpression conditionalExpression) { var test = Visit(conditionalExpression.Test); var ifTrue = Visit(conditionalExpression.IfTrue); var ifFalse = Visit(conditionalExpression.IfFalse); + if (test == null || ifTrue == null || ifFalse == null) { return null; } - return conditionalExpression.Update(test, ifTrue, ifFalse); + if (test.Type == typeof(bool?)) + { + test = Expression.Equal(test, Expression.Constant(true, typeof(bool?))); + } + + if (TypeNullabilityChanged(ifTrue.Type, conditionalExpression.IfTrue.Type) + || TypeNullabilityChanged(ifFalse.Type, conditionalExpression.IfFalse.Type)) + { + ifTrue = MakeNullable(ifTrue); + ifFalse = MakeNullable(ifFalse); + } + + return Expression.Condition(test, ifTrue, ifFalse); } protected override Expression VisitMember(MemberExpression memberExpression) @@ -109,9 +142,27 @@ protected override Expression VisitMember(MemberExpression memberExpression) return result; } - return memberExpression.Update(innerExpression); + var updatedMemberExpression = (Expression)memberExpression.Update(innerExpression); + if (innerExpression != null + && innerExpression.Type.IsNullableType() + && ShouldApplyNullProtectionForMemberAccess(innerExpression.Type, memberExpression.Member.Name)) + { + updatedMemberExpression = MakeNullable(updatedMemberExpression); + + return Expression.Condition( + Expression.Equal(innerExpression, Expression.Default(innerExpression.Type)), + Expression.Default(updatedMemberExpression.Type), + updatedMemberExpression); + } + + return updatedMemberExpression; } + private bool ShouldApplyNullProtectionForMemberAccess(Type callerType, string memberName) + => !(callerType.IsGenericType + && callerType.GetGenericTypeDefinition() == typeof(Nullable<>) + && (memberName == nameof(Nullable.Value) || memberName == nameof(Nullable.HasValue))); + private bool TryBindMember(Expression source, MemberIdentity memberIdentity, Type type, out Expression result) { result = null; @@ -150,7 +201,9 @@ private bool TryBindMember(Expression source, MemberIdentity memberIdentity, Typ } result = BindProperty(entityProjection, property); - if (result.Type != type) + + // if the result type change was just nullability change e.g from int to int? we want to preserve the new type for null propagation + if (result.Type != type && !TypeNullabilityChanged(result.Type, type)) { result = Expression.Convert(result, type); } @@ -161,6 +214,9 @@ private bool TryBindMember(Expression source, MemberIdentity memberIdentity, Typ return false; } + private bool TypeNullabilityChanged(Type maybeNullableType, Type nonNullableType) + => maybeNullableType.IsNullableType() && !nonNullableType.IsNullableType() && maybeNullableType.UnwrapNullableType() == nonNullableType; + private static Expression BindProperty(EntityProjectionExpression entityProjectionExpression, IProperty property) { return entityProjectionExpression.BindProperty(property); @@ -330,6 +386,7 @@ MethodInfo getMethod() } var arguments = new Expression[methodCallExpression.Arguments.Count]; + var parameterTypes = methodCallExpression.Method.GetParameters().Select(p => p.ParameterType).ToArray(); for (var i = 0; i < arguments.Length; i++) { var argument = Visit(methodCallExpression.Arguments[i]); @@ -337,9 +394,37 @@ MethodInfo getMethod() { return null; } + + // if the nullability of arguments change, we have no easy/reliable way to adjust the actual methodInfo to match the new type, + // so we are forced to cast back to the original type + if (argument.Type != methodCallExpression.Arguments[i].Type + && !parameterTypes[i].IsAssignableFrom(argument.Type)) + { + argument = Expression.Convert(argument, methodCallExpression.Arguments[i].Type); + } + arguments[i] = argument; } + // if object is nullable, add null safeguard before calling the function + // we special-case Nullable<>.GetValueOrDefault, which doesn't need the safeguard + if (methodCallExpression.Object != null + && @object.Type.IsNullableType() + && !(methodCallExpression.Method.Name == nameof(Nullable.GetValueOrDefault))) + { + var result = (Expression)methodCallExpression.Update( + Expression.Convert(@object, methodCallExpression.Object.Type), + arguments); + + result = MakeNullable(result); + result = Expression.Condition( + Expression.Equal(@object, Expression.Constant(null, @object.Type)), + Expression.Constant(null, result.Type), + result); + + return result; + } + return methodCallExpression.Update(@object, arguments); } @@ -383,6 +468,68 @@ protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExp return Expression.Constant(false); } + protected override Expression VisitNew(NewExpression newExpression) + { + var newArguments = new List(); + foreach (var argument in newExpression.Arguments) + { + var newArgument = Visit(argument); + if (newArgument.Type != argument.Type) + { + newArgument = Expression.Convert(newArgument, argument.Type); + } + + newArguments.Add(newArgument); + } + + return newExpression.Update(newArguments); + } + + protected override Expression VisitNewArray(NewArrayExpression newArrayExpression) + { + var newExpressions = new List(); + foreach (var expression in newArrayExpression.Expressions) + { + var newExpression = Visit(expression); + if (newExpression.Type != expression.Type) + { + newExpression = Expression.Convert(newExpression, expression.Type); + } + + newExpressions.Add(newExpression); + } + + return newArrayExpression.Update(newExpressions); + } + + protected override Expression VisitMemberInit(MemberInitExpression memberInitExpression) + { + var newExpression = (NewExpression)Visit(memberInitExpression.NewExpression); + var bindings = new List(); + foreach (var binding in memberInitExpression.Bindings) + { + switch (binding) + { + case MemberAssignment memberAssignment: + var expression = Visit(memberAssignment.Expression); + if (expression.Type != memberAssignment.Expression.Type) + { + expression = Expression.Convert(expression, memberAssignment.Expression.Type); + } + + bindings.Add(Expression.Bind(memberAssignment.Member, expression)); + break; + + default: + // TODO: MemberMemberBinding and MemberListBinding + bindings.Add(binding); + break; + } + } + + return memberInitExpression.Update(newExpression, bindings); + } + protected override Expression VisitExtension(Expression extensionExpression) { switch (extensionExpression) @@ -441,7 +588,15 @@ private static T GetParameterValue(QueryContext queryContext, string paramete protected override Expression VisitUnary(UnaryExpression unaryExpression) { - var result = base.VisitUnary(unaryExpression); + var newOperand = Visit(unaryExpression.Operand); + + if (unaryExpression.NodeType == ExpressionType.Convert + && newOperand.Type == unaryExpression.Type) + { + return newOperand; + } + + var result = (Expression)Expression.MakeUnary(unaryExpression.NodeType, newOperand, unaryExpression.Type); if (result is UnaryExpression outerUnary && outerUnary.NodeType == ExpressionType.Convert && outerUnary.Operand is UnaryExpression innerUnary diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryProjectionBindingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryProjectionBindingExpressionVisitor.cs index ded09cef614..649095403cf 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryProjectionBindingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryProjectionBindingExpressionVisitor.cs @@ -127,9 +127,17 @@ public override Expression Visit(Expression expression) } var translation = _expressionTranslatingExpressionVisitor.Translate(expression); - return translation == null - ? base.Visit(expression) - : new ProjectionBindingExpression(_queryExpression, _queryExpression.AddToProjection(translation), expression.Type); + if (translation == null) + { + return base.Visit(expression); + } + + if (translation.Type != expression.Type) + { + translation = Expression.Convert(translation, expression.Type); + } + + return new ProjectionBindingExpression(_queryExpression, _queryExpression.AddToProjection(translation), expression.Type); } else { @@ -139,6 +147,11 @@ public override Expression Visit(Expression expression) return null; } + if (translation.Type != expression.Type) + { + translation = Expression.Convert(translation, expression.Type); + } + _projectionMapping[_projectionMembers.Peek()] = translation; return new ProjectionBindingExpression(_queryExpression, _projectionMembers.Peek(), expression.Type); diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs b/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs index 5de9275fde9..a198ddf478e 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs @@ -13,7 +13,7 @@ namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal { - public partial class InMemoryQueryExpression : Expression + public partial class InMemoryQueryExpression : Expression, IPrintableExpression { private static readonly ConstructorInfo _valueBufferConstructor = typeof(ValueBuffer).GetConstructors().Single(ci => ci.GetParameters().Length == 1); @@ -751,6 +751,31 @@ public virtual void AddSelectMany(InMemoryQueryExpression innerQueryExpression, _projectionMapping = projectionMapping; } + public virtual void Print(ExpressionPrinter expressionPrinter) + { + expressionPrinter.AppendLine(nameof(InMemoryQueryExpression) + ": "); + using (expressionPrinter.Indent()) + { + expressionPrinter.AppendLine(nameof(ServerQueryExpression) + ": "); + using (expressionPrinter.Indent()) + { + expressionPrinter.Visit(ServerQueryExpression); + } + + expressionPrinter.AppendLine("ProjectionMapping:"); + using (expressionPrinter.Indent()) + { + foreach (var projectionMapping in _projectionMapping) + { + expressionPrinter.Append("Member: " + projectionMapping.Key + " Projection: "); + expressionPrinter.Visit(projectionMapping.Value); + } + } + + expressionPrinter.AppendLine(); + } + } + private class NullableReadValueExpressionVisitor : ExpressionVisitor { protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs index 1944ed1f7ff..d1d8fad4e53 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs @@ -64,7 +64,7 @@ protected override ShapedQueryExpression CreateShapedQueryExpression(Type elemen protected override ShapedQueryExpression TranslateAll(ShapedQueryExpression source, LambdaExpression predicate) { var inMemoryQueryExpression = (InMemoryQueryExpression)source.QueryExpression; - predicate = TranslateLambdaExpression(source, predicate); + predicate = TranslateLambdaExpression(source, predicate, preserveType: true); if (predicate == null) { return null; @@ -93,7 +93,7 @@ protected override ShapedQueryExpression TranslateAny(ShapedQueryExpression sour } else { - predicate = TranslateLambdaExpression(source, predicate); + predicate = TranslateLambdaExpression(source, predicate, preserveType: true); if (predicate == null) { return null; @@ -131,17 +131,23 @@ protected override ShapedQueryExpression TranslateConcat(ShapedQueryExpression s protected override ShapedQueryExpression TranslateContains(ShapedQueryExpression source, Expression item) { var inMemoryQueryExpression = (InMemoryQueryExpression)source.QueryExpression; + var itemType = item.Type; item = TranslateExpression(item); if (item == null) { return null; } + if (item.Type != itemType) + { + item = Expression.Convert(item, itemType); + } + inMemoryQueryExpression.ServerQueryExpression = Expression.Call( - InMemoryLinqOperatorProvider.Contains.MakeGenericMethod(item.Type), + InMemoryLinqOperatorProvider.Contains.MakeGenericMethod(itemType), Expression.Call( - InMemoryLinqOperatorProvider.Select.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type, item.Type), + InMemoryLinqOperatorProvider.Select.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type, itemType), inMemoryQueryExpression.ServerQueryExpression, Expression.Lambda( inMemoryQueryExpression.GetMappedProjection(new ProjectionMember()), inMemoryQueryExpression.CurrentParameter)), @@ -165,7 +171,7 @@ protected override ShapedQueryExpression TranslateCount(ShapedQueryExpression so } else { - predicate = TranslateLambdaExpression(source, predicate); + predicate = TranslateLambdaExpression(source, predicate, preserveType: true); if (predicate == null) { return null; @@ -343,6 +349,10 @@ protected override ShapedQueryExpression TranslateJoin(ShapedQueryExpression out return null; } + var unifyNullabilityResult = UnifyNullability(outerKeySelector, innerKeySelector); + outerKeySelector = unifyNullabilityResult.lambda1; + innerKeySelector = unifyNullabilityResult.lambda2; + var transparentIdentifierType = TransparentIdentifierFactory.Create( resultSelector.Parameters[0].Type, resultSelector.Parameters[1].Type); @@ -360,6 +370,27 @@ protected override ShapedQueryExpression TranslateJoin(ShapedQueryExpression out transparentIdentifierType); } + private (LambdaExpression lambda1, LambdaExpression lambda2) UnifyNullability(LambdaExpression lambda1, LambdaExpression lambda2) + { + if (lambda1.Body.Type != lambda2.Body.Type) + { + if (TypeNullabilityChanged(lambda1.Body.Type, lambda2.Body.Type)) + { + lambda2 = Expression.Lambda(Expression.Convert(lambda2.Body, lambda1.Body.Type), lambda2.Parameters); + } + else if (TypeNullabilityChanged(lambda2.Body.Type, lambda1.Body.Type)) + { + lambda1 = Expression.Lambda(Expression.Convert(lambda1.Body, lambda2.Body.Type), lambda1.Parameters); + } + } + + return (lambda1, lambda2); + } + + // TODO: DRY + private bool TypeNullabilityChanged(Type maybeNullableType, Type nonNullableType) + => maybeNullableType.IsNullableType() && !nonNullableType.IsNullableType() && maybeNullableType.UnwrapNullableType() == nonNullableType; + protected override ShapedQueryExpression TranslateLastOrDefault(ShapedQueryExpression source, LambdaExpression predicate, Type returnType, bool returnDefault) { return TranslateSingleResultOperator( @@ -380,6 +411,10 @@ protected override ShapedQueryExpression TranslateLeftJoin(ShapedQueryExpression return null; } + var unifyNullabilityResult = UnifyNullability(outerKeySelector, innerKeySelector); + outerKeySelector = unifyNullabilityResult.lambda1; + innerKeySelector = unifyNullabilityResult.lambda2; + var transparentIdentifierType = TransparentIdentifierFactory.Create( resultSelector.Parameters[0].Type, resultSelector.Parameters[1].Type); @@ -410,7 +445,7 @@ protected override ShapedQueryExpression TranslateLongCount(ShapedQueryExpressio } else { - predicate = TranslateLambdaExpression(source, predicate); + predicate = TranslateLambdaExpression(source, predicate, preserveType: true); if (predicate == null) { return null; @@ -478,8 +513,8 @@ protected override ShapedQueryExpression TranslateOfType(ShapedQueryExpression s Expression.Constant(derivedDerivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType))); } - var predicate = TranslateLambdaExpression(source, Expression.Lambda(equals, parameter)); - if (predicate == null) + var discriminatorPredicate = TranslateLambdaExpression(source, Expression.Lambda(equals, parameter)); + if (discriminatorPredicate == null) { return null; } @@ -487,7 +522,7 @@ protected override ShapedQueryExpression TranslateOfType(ShapedQueryExpression s inMemoryQueryExpression.ServerQueryExpression = Expression.Call( InMemoryLinqOperatorProvider.Where.MakeGenericMethod(typeof(ValueBuffer)), inMemoryQueryExpression.ServerQueryExpression, - predicate); + discriminatorPredicate); var projectionBindingExpression = (ProjectionBindingExpression)entityShaperExpression.ValueBufferExpression; var projectionMember = projectionBindingExpression.ProjectionMember; @@ -703,7 +738,7 @@ protected override ShapedQueryExpression TranslateUnion(ShapedQueryExpression so protected override ShapedQueryExpression TranslateWhere(ShapedQueryExpression source, LambdaExpression predicate) { var inMemoryQueryExpression = (InMemoryQueryExpression)source.QueryExpression; - predicate = TranslateLambdaExpression(source, predicate); + predicate = TranslateLambdaExpression(source, predicate, preserveType: true); if (predicate == null) { return null; @@ -723,10 +758,21 @@ private Expression TranslateExpression(Expression expression) } private LambdaExpression TranslateLambdaExpression( - ShapedQueryExpression shapedQueryExpression, LambdaExpression lambdaExpression) + ShapedQueryExpression shapedQueryExpression, + LambdaExpression lambdaExpression, + bool preserveType = false) { var lambdaBody = TranslateExpression(RemapLambdaBody(shapedQueryExpression, lambdaExpression)); + if (lambdaBody != null && preserveType) + { + lambdaBody = lambdaBody.Type == typeof(bool?) + ? Expression.Equal( + lambdaBody, + Expression.Constant(true, typeof(bool?))) + : lambdaBody; + } + return lambdaBody != null ? Expression.Lambda(lambdaBody, ((InMemoryQueryExpression)shapedQueryExpression.QueryExpression).CurrentParameter) @@ -744,6 +790,8 @@ private ShapedQueryExpression TranslateScalarAggregate( { var inMemoryQueryExpression = (InMemoryQueryExpression)source.QueryExpression; + var selectorBodyType = selector?.Body.Type; + selector = selector == null || selector.Body == selector.Parameters[0] ? Expression.Lambda( @@ -756,6 +804,11 @@ private ShapedQueryExpression TranslateScalarAggregate( return null; } + if (selectorBodyType != null && selector.Body.Type != selectorBodyType) + { + selector = Expression.Lambda(Expression.Convert(selector.Body, selectorBodyType), selector.Parameters); + } + MethodInfo getMethod() => methodName switch { diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryTableExpression.cs b/src/EFCore.InMemory/Query/Internal/InMemoryTableExpression.cs index 5318823b6e5..e038c490147 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryTableExpression.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryTableExpression.cs @@ -5,11 +5,12 @@ using System.Collections.Generic; using System.Linq.Expressions; using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Query; using Microsoft.EntityFrameworkCore.Storage; namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal { - public class InMemoryTableExpression : Expression + public class InMemoryTableExpression : Expression, IPrintableExpression { public InMemoryTableExpression(IEntityType entityType) { @@ -26,5 +27,10 @@ protected override Expression VisitChildren(ExpressionVisitor visitor) { return this; } + + public virtual void Print(ExpressionPrinter expressionPrinter) + { + expressionPrinter.Append(nameof(InMemoryTableExpression) + ": Entity: " + EntityType.DisplayName()); + } } } diff --git a/src/EFCore/Query/Internal/SubqueryMemberPushdownExpressionVisitor.cs b/src/EFCore/Query/Internal/SubqueryMemberPushdownExpressionVisitor.cs index 06a73286b76..92925ad8628 100644 --- a/src/EFCore/Query/Internal/SubqueryMemberPushdownExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/SubqueryMemberPushdownExpressionVisitor.cs @@ -52,6 +52,7 @@ protected override Expression VisitMember(MemberExpression memberExpression) (target, nullable) => { var memberAccessExpression = Expression.MakeMemberAccess(target, memberExpression.Member); + return nullable && !memberAccessExpression.Type.IsNullableType() ? Expression.Convert(memberAccessExpression, memberAccessExpression.Type.MakeNullable()) : (Expression)memberAccessExpression; diff --git a/test/EFCore.InMemory.FunctionalTests/Query/ComplexNavigationsQueryInMemoryTest.cs b/test/EFCore.InMemory.FunctionalTests/Query/ComplexNavigationsQueryInMemoryTest.cs index 86b164eb948..53cce45a98f 100644 --- a/test/EFCore.InMemory.FunctionalTests/Query/ComplexNavigationsQueryInMemoryTest.cs +++ b/test/EFCore.InMemory.FunctionalTests/Query/ComplexNavigationsQueryInMemoryTest.cs @@ -7,12 +7,180 @@ namespace Microsoft.EntityFrameworkCore.Query { - internal class ComplexNavigationsQueryInMemoryTest : ComplexNavigationsQueryTestBase + public class ComplexNavigationsQueryInMemoryTest : ComplexNavigationsQueryTestBase { public ComplexNavigationsQueryInMemoryTest(ComplexNavigationsQueryInMemoryFixture fixture, ITestOutputHelper testOutputHelper) : base(fixture) { //TestLoggerFactory.TestOutputHelper = testOutputHelper; } + + [ConditionalTheory(Skip = "issue #16963")] //DefaultIfEmpty + public override Task Complex_SelectMany_with_nested_navigations_and_explicit_DefaultIfEmpty_with_other_query_operators_composed_on_top(bool isAsync) + { + return base.Complex_SelectMany_with_nested_navigations_and_explicit_DefaultIfEmpty_with_other_query_operators_composed_on_top(isAsync); + } + + [ConditionalTheory(Skip = "issue #16963")] //DefaultIfEmpty + public override Task Multiple_SelectMany_with_navigation_and_explicit_DefaultIfEmpty(bool isAsync) + { + return base.Multiple_SelectMany_with_navigation_and_explicit_DefaultIfEmpty(isAsync); + } + + [ConditionalTheory(Skip = "issue #16963")] //DefaultIfEmpty + public override Task Multiple_SelectMany_with_nested_navigations_and_explicit_DefaultIfEmpty_joined_together(bool isAsync) + { + return base.Multiple_SelectMany_with_nested_navigations_and_explicit_DefaultIfEmpty_joined_together(isAsync); + } + + [ConditionalTheory(Skip = "issue #16963")] //DefaultIfEmpty + public override Task SelectMany_with_navigation_and_explicit_DefaultIfEmpty(bool isAsync) + { + return base.SelectMany_with_navigation_and_explicit_DefaultIfEmpty(isAsync); + } + + [ConditionalTheory(Skip = "issue #16963")] //DefaultIfEmpty + public override Task SelectMany_with_navigation_filter_and_explicit_DefaultIfEmpty(bool isAsync) + { + return base.SelectMany_with_navigation_filter_and_explicit_DefaultIfEmpty(isAsync); + } + + [ConditionalTheory(Skip = "issue #16963")] //DefaultIfEmpty + public override Task SelectMany_with_navigation_filter_paging_and_explicit_DefaultIfEmpty(bool isAsync) + { + return base.SelectMany_with_navigation_filter_paging_and_explicit_DefaultIfEmpty(isAsync); + } + + [ConditionalTheory(Skip = "issue #16963")] //DefaultIfEmpty + public override Task SelectMany_with_nested_navigations_and_explicit_DefaultIfEmpty_followed_by_Select_required_navigation_using_different_navs(bool isAsync) + { + return base.SelectMany_with_nested_navigations_and_explicit_DefaultIfEmpty_followed_by_Select_required_navigation_using_different_navs(isAsync); + } + + [ConditionalTheory(Skip = "issue #16963")] //DefaultIfEmpty + public override Task SelectMany_with_nested_navigations_and_explicit_DefaultIfEmpty_followed_by_Select_required_navigation_using_same_navs(bool isAsync) + { + return base.SelectMany_with_nested_navigations_and_explicit_DefaultIfEmpty_followed_by_Select_required_navigation_using_same_navs(isAsync); + } + + [ConditionalTheory(Skip = "issue #16963")] //DefaultIfEmpty + public override Task SelectMany_with_nested_navigations_explicit_DefaultIfEmpty_and_additional_joins_outside_of_SelectMany(bool isAsync) + { + return base.SelectMany_with_nested_navigations_explicit_DefaultIfEmpty_and_additional_joins_outside_of_SelectMany(isAsync); + } + + [ConditionalTheory(Skip = "issue #16963")] //DefaultIfEmpty + public override Task SelectMany_with_nested_navigations_explicit_DefaultIfEmpty_and_additional_joins_outside_of_SelectMany2(bool isAsync) + { + return base.SelectMany_with_nested_navigations_explicit_DefaultIfEmpty_and_additional_joins_outside_of_SelectMany2(isAsync); + } + + [ConditionalTheory(Skip = "issue #16963")] //DefaultIfEmpty + public override Task SelectMany_with_nested_navigations_explicit_DefaultIfEmpty_and_additional_joins_outside_of_SelectMany3(bool isAsync) + { + return base.SelectMany_with_nested_navigations_explicit_DefaultIfEmpty_and_additional_joins_outside_of_SelectMany3(isAsync); + } + + [ConditionalTheory(Skip = "issue #16963")] //DefaultIfEmpty + public override Task SelectMany_with_nested_navigations_explicit_DefaultIfEmpty_and_additional_joins_outside_of_SelectMany4(bool isAsync) + { + return base.SelectMany_with_nested_navigations_explicit_DefaultIfEmpty_and_additional_joins_outside_of_SelectMany4(isAsync); + } + + [ConditionalTheory(Skip = "issue #16963")] //DefaultIfEmpty + public override Task SelectMany_with_nested_navigation_and_explicit_DefaultIfEmpty(bool isAsync) + { + return base.SelectMany_with_nested_navigation_and_explicit_DefaultIfEmpty(isAsync); + } + + [ConditionalTheory(Skip = "issue #16963")] //DefaultIfEmpty + public override Task SelectMany_with_nested_navigation_filter_and_explicit_DefaultIfEmpty(bool isAsync) + { + return base.SelectMany_with_nested_navigation_filter_and_explicit_DefaultIfEmpty(isAsync); + } + + [ConditionalTheory(Skip = "issue #16963")] //DefaultIfEmpty + public override Task SelectMany_with_nested_required_navigation_filter_and_explicit_DefaultIfEmpty(bool isAsync) + { + return base.SelectMany_with_nested_required_navigation_filter_and_explicit_DefaultIfEmpty(isAsync); + } + + [ConditionalTheory(Skip = "issue #16963")] //GroupBy + public override Task Simple_level1_level2_GroupBy_Count(bool isAsync) + { + return base.Simple_level1_level2_GroupBy_Count(isAsync); + } + + [ConditionalTheory(Skip = "issue #16963")] //GroupBy + public override Task Simple_level1_level2_GroupBy_Having_Count(bool isAsync) + { + return base.Simple_level1_level2_GroupBy_Having_Count(isAsync); + } + + [ConditionalTheory(Skip = "issue #17386")] + public override Task Complex_query_with_optional_navigations_and_client_side_evaluation(bool isAsync) + { + return base.Complex_query_with_optional_navigations_and_client_side_evaluation(isAsync); + } + + [ConditionalTheory(Skip = "issue #17453")] + public override Task Project_collection_navigation_nested(bool isAsync) + { + return base.Project_collection_navigation_nested(isAsync); + } + + [ConditionalTheory(Skip = "issue #17453")] + public override Task Project_collection_navigation_nested_anonymous(bool isAsync) + { + return base.Project_collection_navigation_nested_anonymous(isAsync); + } + + [ConditionalTheory(Skip = "issue #17453")] + public override Task Project_collection_navigation_using_ef_property(bool isAsync) + { + return base.Project_collection_navigation_using_ef_property(isAsync); + } + + [ConditionalTheory(Skip = "issue #17453")] + public override Task Project_navigation_and_collection(bool isAsync) + { + return base.Project_navigation_and_collection(isAsync); + } + + [ConditionalTheory(Skip = "issue #17453")] + public override Task SelectMany_nested_navigation_property_optional_and_projection(bool isAsync) + { + return base.SelectMany_nested_navigation_property_optional_and_projection(isAsync); + } + + [ConditionalTheory(Skip = "issue #17453")] + public override Task SelectMany_nested_navigation_property_required(bool isAsync) + { + return base.SelectMany_nested_navigation_property_required(isAsync); + } + + [ConditionalTheory(Skip = "issue #17460")] + public override Task Where_complex_predicate_with_with_nav_prop_and_OrElse4(bool isAsync) + { + return base.Where_complex_predicate_with_with_nav_prop_and_OrElse4(isAsync); + } + + [ConditionalTheory(Skip = "issue #17460")] + public override Task Join_flattening_bug_4539(bool isAsync) + { + return base.Join_flattening_bug_4539(isAsync); + } + + [ConditionalTheory(Skip = "issue #17463")] + public override Task Include18_3_3(bool isAsync) + { + return base.Include18_3_3(isAsync); + } + + [ConditionalFact(Skip = "issue #17463")] + public override void Include19() + { + base.Include19(); + } } } diff --git a/test/EFCore.InMemory.FunctionalTests/Query/GearsOfWarQueryInMemoryTest.cs b/test/EFCore.InMemory.FunctionalTests/Query/GearsOfWarQueryInMemoryTest.cs index 75f037b9809..050d95bd736 100644 --- a/test/EFCore.InMemory.FunctionalTests/Query/GearsOfWarQueryInMemoryTest.cs +++ b/test/EFCore.InMemory.FunctionalTests/Query/GearsOfWarQueryInMemoryTest.cs @@ -7,12 +7,114 @@ namespace Microsoft.EntityFrameworkCore.Query { - internal class GearsOfWarQueryInMemoryTest : GearsOfWarQueryTestBase + public class GearsOfWarQueryInMemoryTest : GearsOfWarQueryTestBase { public GearsOfWarQueryInMemoryTest(GearsOfWarQueryInMemoryFixture fixture, ITestOutputHelper testOutputHelper) : base(fixture) { //TestLoggerFactory.TestOutputHelper = testOutputHelper; } + + [ConditionalTheory(Skip = "issue #16963")] // groupby + public override Task GroupBy_Property_Include_Aggregate_with_anonymous_selector(bool isAsync) + { + return base.GroupBy_Property_Include_Aggregate_with_anonymous_selector(isAsync); + } + + [ConditionalTheory(Skip = "issue #16963")] // groupby + public override Task GroupBy_Property_Include_Select_Count(bool isAsync) + { + return base.GroupBy_Property_Include_Select_Count(isAsync); + } + + [ConditionalTheory(Skip = "issue #16963")] // groupby + public override Task GroupBy_Property_Include_Select_LongCount(bool isAsync) + { + return base.GroupBy_Property_Include_Select_LongCount(isAsync); + } + + [ConditionalTheory(Skip = "issue #16963")] // groupby + public override Task GroupBy_Property_Include_Select_Max(bool isAsync) + { + return base.GroupBy_Property_Include_Select_Max(isAsync); + } + + [ConditionalTheory(Skip = "issue #16963")] // groupby + public override Task GroupBy_Property_Include_Select_Min(bool isAsync) + { + return base.GroupBy_Property_Include_Select_Min(isAsync); + } + + [ConditionalTheory(Skip = "issue #17386")] + public override Task Correlated_collection_order_by_constant_null_of_non_mapped_type(bool isAsync) + { + return base.Correlated_collection_order_by_constant_null_of_non_mapped_type(isAsync); + } + + [ConditionalTheory(Skip = "issue #17386")] + public override Task Client_side_equality_with_parameter_works_with_optional_navigations(bool isAsync) + { + return base.Client_side_equality_with_parameter_works_with_optional_navigations(isAsync); + } + + [ConditionalTheory(Skip = "issue #17386")] + public override Task Where_coalesce_with_anonymous_types(bool isAsync) + { + return base.Where_coalesce_with_anonymous_types(isAsync); + } + + [ConditionalTheory(Skip = "issue #17386")] + public override Task Where_conditional_with_anonymous_type(bool isAsync) + { + return base.Where_conditional_with_anonymous_type(isAsync); + } + + [ConditionalTheory(Skip = "issue #17386")] + public override Task GetValueOrDefault_on_DateTimeOffset(bool isAsync) + { + return base.GetValueOrDefault_on_DateTimeOffset(isAsync); + } + + [ConditionalTheory(Skip = "issue #17453")] + public override Task Correlated_collection_with_complex_OrderBy(bool isAsync) + { + return base.Correlated_collection_with_complex_OrderBy(isAsync); + } + + [ConditionalTheory(Skip = "issue #17453")] + public override Task Correlated_collection_with_very_complex_order_by(bool isAsync) + { + return base.Correlated_collection_with_very_complex_order_by(isAsync); + } + + [ConditionalTheory(Skip = "issue #17463")] + public override Task Include_collection_OrderBy_aggregate(bool isAsync) + { + return base.Include_collection_OrderBy_aggregate(isAsync); + } + + [ConditionalTheory(Skip = "issue #17463")] + public override Task Include_collection_with_complex_OrderBy3(bool isAsync) + { + return base.Include_collection_with_complex_OrderBy3(isAsync); + } + + [ConditionalTheory(Skip = "issue #17463")] + public override void Include_on_GroupJoin_SelectMany_DefaultIfEmpty_with_coalesce_result1() + { + base.Include_on_GroupJoin_SelectMany_DefaultIfEmpty_with_coalesce_result1(); + } + + [ConditionalTheory(Skip = "issue #17463")] + public override void Include_on_GroupJoin_SelectMany_DefaultIfEmpty_with_coalesce_result2() + { + base.Include_on_GroupJoin_SelectMany_DefaultIfEmpty_with_coalesce_result2(); + } + + [ConditionalTheory(Skip = "issue #16963")] //length + public override Task Null_semantics_is_correctly_applied_for_function_comparisons_that_take_arguments_from_optional_navigation_complex(bool isAsync) + { + return base.Null_semantics_is_correctly_applied_for_function_comparisons_that_take_arguments_from_optional_navigation_complex(isAsync); + } } } diff --git a/test/EFCore.InMemory.FunctionalTests/Query/QueryFilterFuncletizationInMemoryTest.cs b/test/EFCore.InMemory.FunctionalTests/Query/QueryFilterFuncletizationInMemoryTest.cs index d4a85fa9f7f..f2ed8b793de 100644 --- a/test/EFCore.InMemory.FunctionalTests/Query/QueryFilterFuncletizationInMemoryTest.cs +++ b/test/EFCore.InMemory.FunctionalTests/Query/QueryFilterFuncletizationInMemoryTest.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using Microsoft.EntityFrameworkCore.TestUtilities; +using Xunit; using Xunit.Abstractions; namespace Microsoft.EntityFrameworkCore.Query @@ -19,5 +20,11 @@ public class QueryFilterFuncletizationInMemoryFixture : QueryFilterFuncletizatio { protected override ITestStoreFactory TestStoreFactory => InMemoryTestStoreFactory.Instance; } + + [ConditionalFact(Skip = "issue #17386")] + public override void DbContext_list_is_parameterized() + { + base.DbContext_list_is_parameterized(); + } } } diff --git a/test/EFCore.InMemory.FunctionalTests/Query/SimpleQueryInMemoryTest.cs b/test/EFCore.InMemory.FunctionalTests/Query/SimpleQueryInMemoryTest.cs index b643a6ffa45..ee8f6df16b0 100644 --- a/test/EFCore.InMemory.FunctionalTests/Query/SimpleQueryInMemoryTest.cs +++ b/test/EFCore.InMemory.FunctionalTests/Query/SimpleQueryInMemoryTest.cs @@ -2,7 +2,10 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Dynamic; +using System.Linq; using System.Runtime.InteropServices.WindowsRuntime; +using System.Text.RegularExpressions; using System.Threading.Tasks; using Microsoft.EntityFrameworkCore.TestUtilities; using Xunit; @@ -201,5 +204,11 @@ public override Task Default_if_empty_top_level_projection(bool isAsync) } #endregion + + [ConditionalTheory(Skip = "issue #17386")] + public override Task Where_equals_on_null_nullable_int_types(bool isAsync) + { + return base.Where_equals_on_null_nullable_int_types(isAsync); + } } } diff --git a/test/EFCore.InMemory.FunctionalTests/Query/SpatialQueryInMemoryTest.cs b/test/EFCore.InMemory.FunctionalTests/Query/SpatialQueryInMemoryTest.cs index a8a7f1c5219..cbdbfeaba0d 100644 --- a/test/EFCore.InMemory.FunctionalTests/Query/SpatialQueryInMemoryTest.cs +++ b/test/EFCore.InMemory.FunctionalTests/Query/SpatialQueryInMemoryTest.cs @@ -1,9 +1,6 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System.Threading.Tasks; -using Xunit; - namespace Microsoft.EntityFrameworkCore.Query { public class SpatialQueryInMemoryTest : SpatialQueryTestBase @@ -12,125 +9,5 @@ public SpatialQueryInMemoryTest(SpatialQueryInMemoryFixture fixture) : base(fixture) { } - - [ConditionalTheory(Skip = "Issue #16963. Nullable error")] - public override Task Area(bool isAsync) - => base.Area(isAsync); - - [ConditionalTheory(Skip = "Issue #16963. Nullable error")] - public override Task Boundary(bool isAsync) - => base.Boundary(isAsync); - - [ConditionalTheory(Skip = "Issue #16963. Nullable error")] - public override Task Centroid(bool isAsync) - => base.Centroid(isAsync); - - [ConditionalTheory(Skip = "Issue #16963. Nullable error")] - public override Task Dimension(bool isAsync) - => base.Dimension(isAsync); - - [ConditionalTheory(Skip = "Issue #16963. Nullable error")] - public override Task EndPoint(bool isAsync) - => base.EndPoint(isAsync); - - [ConditionalTheory(Skip = "Issue #16963. Nullable error")] - public override Task Envelope(bool isAsync) - => base.Envelope(isAsync); - - [ConditionalTheory(Skip = "Issue #16963. Nullable error")] - public override Task ExteriorRing(bool isAsync) - => base.ExteriorRing(isAsync); - - [ConditionalTheory(Skip = "Issue #16963. Nullable error")] - public override Task GeometryType(bool isAsync) - => base.GeometryType(isAsync); - - [ConditionalTheory(Skip = "Issue #16963. Nullable error")] - public override Task ICurve_IsClosed(bool isAsync) - => base.ICurve_IsClosed(isAsync); - - [ConditionalTheory(Skip = "Issue #16963. Nullable error")] - public override Task IGeometryCollection_Count(bool isAsync) - => base.IGeometryCollection_Count(isAsync); - - [ConditionalTheory(Skip = "Issue #16963. Nullable error")] - public override Task IMultiCurve_IsClosed(bool isAsync) - => base.IMultiCurve_IsClosed(isAsync); - - [ConditionalTheory(Skip = "Issue #16963. Nullable error")] - public override Task InteriorPoint(bool isAsync) - => base.InteriorPoint(isAsync); - - [ConditionalTheory(Skip = "Issue #16963. Nullable error")] - public override Task IsEmpty(bool isAsync) - => base.IsEmpty(isAsync); - - [ConditionalTheory(Skip = "Issue #16963. Nullable error")] - public override Task IsRing(bool isAsync) - => base.IsRing(isAsync); - - [ConditionalTheory(Skip = "Issue #16963. Nullable error")] - public override Task IsSimple(bool isAsync) - => base.IsSimple(isAsync); - - [ConditionalTheory(Skip = "Issue #16963. Nullable error")] - public override Task IsValid(bool isAsync) - => base.IsValid(isAsync); - - [ConditionalTheory(Skip = "Issue #16963. Nullable error")] - public override Task Length(bool isAsync) - => base.Length(isAsync); - - [ConditionalTheory(Skip = "Issue #16963. Nullable error")] - public override Task LineString_Count(bool isAsync) - => base.LineString_Count(isAsync); - - [ConditionalTheory(Skip = "Issue #16963. Nullable error")] - public override Task M(bool isAsync) - => base.M(isAsync); - - [ConditionalTheory(Skip = "Issue #16963. Nullable error")] - public override Task NumGeometries(bool isAsync) - => base.NumGeometries(isAsync); - - [ConditionalTheory(Skip = "Issue #16963. Nullable error")] - public override Task NumInteriorRings(bool isAsync) - => base.NumInteriorRings(isAsync); - - [ConditionalTheory(Skip = "Issue #16963. Nullable error")] - public override Task NumPoints(bool isAsync) - => base.NumPoints(isAsync); - - [ConditionalTheory(Skip = "Issue #16963. Nullable error")] - public override Task OgcGeometryType(bool isAsync) - => base.OgcGeometryType(isAsync); - - [ConditionalTheory(Skip = "Issue #16963. Nullable error")] - public override Task PointOnSurface(bool isAsync) - => base.PointOnSurface(isAsync); - - [ConditionalTheory(Skip = "Issue #16963. Nullable error")] - public override Task SRID(bool isAsync) - => base.SRID(isAsync); - - [ConditionalTheory(Skip = "Issue #16963. Nullable error")] - public override Task SRID_geometry(bool isAsync) - => base.SRID_geometry(isAsync); - - [ConditionalTheory(Skip = "Issue #16963. Nullable error")] - public override Task StartPoint(bool isAsync) - => base.StartPoint(isAsync); - - [ConditionalTheory(Skip = "Issue #16963. Nullable error")] - public override Task X(bool isAsync) - => base.X(isAsync); - - [ConditionalTheory(Skip = "Issue #16963. Nullable error")] - public override Task Y(bool isAsync) - => base.Y(isAsync); - - [ConditionalTheory(Skip = "Issue #16963. Nullable error")] - public override Task Z(bool isAsync) - => base.Z(isAsync); } } From 25d30bf083f25883577501f01f19d1b6e466855c Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Wed, 28 Aug 2019 19:03:06 -0700 Subject: [PATCH 3/3] InMemory: Some refactorings Part of #16963 --- ...yExpressionTranslatingExpressionVisitor.cs | 105 ++++++++---------- .../Internal/InMemoryLinqOperatorProvider.cs | 15 ++- .../Query/Internal/InMemoryQueryExpression.cs | 6 +- ...yableMethodTranslatingExpressionVisitor.cs | 82 ++++++-------- 4 files changed, 98 insertions(+), 110 deletions(-) diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs index cc54fa1e294..714b4d01dcf 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs @@ -78,11 +78,11 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) return null; } - if (TypeNullabilityChanged(newLeft.Type, binaryExpression.Left.Type) - || TypeNullabilityChanged(newRight.Type, binaryExpression.Right.Type)) + if (IsConvertedToNullable(newLeft, binaryExpression.Left) + || IsConvertedToNullable(newRight, binaryExpression.Right)) { - newLeft = MakeNullable(newLeft); - newRight = MakeNullable(newRight); + newLeft = ConvertToNullable(newLeft); + newRight = ConvertToNullable(newRight); } return Expression.MakeBinary( @@ -94,11 +94,6 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) binaryExpression.Conversion); } - private static Expression MakeNullable(Expression expression) - => !expression.Type.IsNullableType() - ? Expression.Convert(expression, expression.Type.MakeNullable()) - : expression; - protected override Expression VisitConditional(ConditionalExpression conditionalExpression) { var test = Visit(conditionalExpression.Test); @@ -115,11 +110,11 @@ protected override Expression VisitConditional(ConditionalExpression conditional test = Expression.Equal(test, Expression.Constant(true, typeof(bool?))); } - if (TypeNullabilityChanged(ifTrue.Type, conditionalExpression.IfTrue.Type) - || TypeNullabilityChanged(ifFalse.Type, conditionalExpression.IfFalse.Type)) + if (IsConvertedToNullable(ifTrue, conditionalExpression.IfTrue) + || IsConvertedToNullable(ifFalse, conditionalExpression.IfFalse)) { - ifTrue = MakeNullable(ifTrue); - ifFalse = MakeNullable(ifFalse); + ifTrue = ConvertToNullable(ifTrue); + ifFalse = ConvertToNullable(ifFalse); } return Expression.Condition(test, ifTrue, ifFalse); @@ -142,12 +137,17 @@ protected override Expression VisitMember(MemberExpression memberExpression) return result; } + static bool shouldApplyNullProtectionForMemberAccess(Type callerType, string memberName) + => !(callerType.IsGenericType + && callerType.GetGenericTypeDefinition() == typeof(Nullable<>) + && (memberName == nameof(Nullable.Value) || memberName == nameof(Nullable.HasValue))); + var updatedMemberExpression = (Expression)memberExpression.Update(innerExpression); if (innerExpression != null && innerExpression.Type.IsNullableType() - && ShouldApplyNullProtectionForMemberAccess(innerExpression.Type, memberExpression.Member.Name)) + && shouldApplyNullProtectionForMemberAccess(innerExpression.Type, memberExpression.Member.Name)) { - updatedMemberExpression = MakeNullable(updatedMemberExpression); + updatedMemberExpression = ConvertToNullable(updatedMemberExpression); return Expression.Condition( Expression.Equal(innerExpression, Expression.Default(innerExpression.Type)), @@ -158,11 +158,6 @@ protected override Expression VisitMember(MemberExpression memberExpression) return updatedMemberExpression; } - private bool ShouldApplyNullProtectionForMemberAccess(Type callerType, string memberName) - => !(callerType.IsGenericType - && callerType.GetGenericTypeDefinition() == typeof(Nullable<>) - && (memberName == nameof(Nullable.Value) || memberName == nameof(Nullable.HasValue))); - private bool TryBindMember(Expression source, MemberIdentity memberIdentity, Type type, out Expression result) { result = null; @@ -203,7 +198,10 @@ private bool TryBindMember(Expression source, MemberIdentity memberIdentity, Typ result = BindProperty(entityProjection, property); // if the result type change was just nullability change e.g from int to int? we want to preserve the new type for null propagation - if (result.Type != type && !TypeNullabilityChanged(result.Type, type)) + if (result.Type != type + && !(result.Type.IsNullableType() + && !type.IsNullableType() + && result.Type.UnwrapNullableType() == type)) { result = Expression.Convert(result, type); } @@ -214,13 +212,23 @@ private bool TryBindMember(Expression source, MemberIdentity memberIdentity, Typ return false; } - private bool TypeNullabilityChanged(Type maybeNullableType, Type nonNullableType) - => maybeNullableType.IsNullableType() && !nonNullableType.IsNullableType() && maybeNullableType.UnwrapNullableType() == nonNullableType; + private static bool IsConvertedToNullable(Expression result, Expression original) + => result.Type.IsNullableType() + && !original.Type.IsNullableType() + && result.Type.UnwrapNullableType() == original.Type; + + private static Expression ConvertToNullable(Expression expression) + => !expression.Type.IsNullableType() + ? Expression.Convert(expression, expression.Type.MakeNullable()) + : expression; + + private static Expression ConvertToNonNullable(Expression expression) + => expression.Type.IsNullableType() + ? Expression.Convert(expression, expression.Type.UnwrapNullableType()) + : expression; private static Expression BindProperty(EntityProjectionExpression entityProjectionExpression, IProperty property) - { - return entityProjectionExpression.BindProperty(property); - } + => entityProjectionExpression.BindProperty(property); private static Expression GetSelector(MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression) { @@ -386,7 +394,7 @@ MethodInfo getMethod() } var arguments = new Expression[methodCallExpression.Arguments.Count]; - var parameterTypes = methodCallExpression.Method.GetParameters().Select(p => p.ParameterType).ToArray(); + var parameterTypes = methodCallExpression.Method.GetParameters().Select(p => p.ParameterType).ToArray(); for (var i = 0; i < arguments.Length; i++) { var argument = Visit(methodCallExpression.Arguments[i]); @@ -397,10 +405,10 @@ MethodInfo getMethod() // if the nullability of arguments change, we have no easy/reliable way to adjust the actual methodInfo to match the new type, // so we are forced to cast back to the original type - if (argument.Type != methodCallExpression.Arguments[i].Type + if (IsConvertedToNullable(argument, methodCallExpression.Arguments[i]) && !parameterTypes[i].IsAssignableFrom(argument.Type)) { - argument = Expression.Convert(argument, methodCallExpression.Arguments[i].Type); + argument = ConvertToNonNullable(argument); } arguments[i] = argument; @@ -412,11 +420,11 @@ MethodInfo getMethod() && @object.Type.IsNullableType() && !(methodCallExpression.Method.Name == nameof(Nullable.GetValueOrDefault))) { - var result = (Expression)methodCallExpression.Update( + var result = (Expression)methodCallExpression.Update( Expression.Convert(@object, methodCallExpression.Object.Type), arguments); - result = MakeNullable(result); + result = ConvertToNullable(result); result = Expression.Condition( Expression.Equal(@object, Expression.Constant(null, @object.Type)), Expression.Constant(null, result.Type), @@ -474,9 +482,9 @@ protected override Expression VisitNew(NewExpression newExpression) foreach (var argument in newExpression.Arguments) { var newArgument = Visit(argument); - if (newArgument.Type != argument.Type) + if (IsConvertedToNullable(newArgument, argument)) { - newArgument = Expression.Convert(newArgument, argument.Type); + newArgument = ConvertToNonNullable(newArgument); } newArguments.Add(newArgument); @@ -491,9 +499,9 @@ protected override Expression VisitNewArray(NewArrayExpression newArrayExpressio foreach (var expression in newArrayExpression.Expressions) { var newExpression = Visit(expression); - if (newExpression.Type != expression.Type) + if (IsConvertedToNullable(newExpression, expression)) { - newExpression = Expression.Convert(newExpression, expression.Type); + newExpression = ConvertToNonNullable(newExpression); } newExpressions.Add(newExpression); @@ -502,32 +510,15 @@ protected override Expression VisitNewArray(NewArrayExpression newArrayExpressio return newArrayExpression.Update(newExpressions); } - protected override Expression VisitMemberInit(MemberInitExpression memberInitExpression) + protected override MemberAssignment VisitMemberAssignment(MemberAssignment memberAssignment) { - var newExpression = (NewExpression)Visit(memberInitExpression.NewExpression); - var bindings = new List(); - foreach (var binding in memberInitExpression.Bindings) + var expression = Visit(memberAssignment.Expression); + if (IsConvertedToNullable(expression, memberAssignment.Expression)) { - switch (binding) - { - case MemberAssignment memberAssignment: - var expression = Visit(memberAssignment.Expression); - if (expression.Type != memberAssignment.Expression.Type) - { - expression = Expression.Convert(expression, memberAssignment.Expression.Type); - } - - bindings.Add(Expression.Bind(memberAssignment.Member, expression)); - break; - - default: - // TODO: MemberMemberBinding and MemberListBinding - bindings.Add(binding); - break; - } + expression = ConvertToNonNullable(expression); } - return memberInitExpression.Update(newExpression, bindings); + return memberAssignment.Update(expression); } protected override Expression VisitExtension(Expression extensionExpression) diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryLinqOperatorProvider.cs b/src/EFCore.InMemory/Query/Internal/InMemoryLinqOperatorProvider.cs index 3d7f53aee47..34b561f941a 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryLinqOperatorProvider.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryLinqOperatorProvider.cs @@ -108,9 +108,20 @@ public static MethodInfo GetMinWithSelector(Type type) private static Dictionary SumWithoutSelectorMethods { get; } private static Dictionary SumWithSelectorMethods { get; } - private static bool IsFunc(Type type, int funcGenericArgs = 2) + private static Type GetFuncType(int funcGenericArguments) + { + return funcGenericArguments switch + { + 1 => typeof(Func<>), + 2 => typeof(Func<,>), + 3 => typeof(Func<,,>), + 4 => typeof(Func<,,,>), + _ => throw new InvalidOperationException("Invalid number of arguments for Func"), + }; + } + private static bool IsFunc(Type type, int funcGenericArguments = 2) => type.IsGenericType - && type.GetGenericArguments().Length == funcGenericArgs; + && type.GetGenericTypeDefinition() == GetFuncType(funcGenericArguments); static InMemoryLinqOperatorProvider() { diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs b/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs index a198ddf478e..d6b8c8fb669 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs @@ -51,8 +51,8 @@ public InMemoryQueryExpression(IEntityType entityType) Constant(property.GetIndex()), MakeMemberAccess(_valueBufferParameter, _valueBufferCountMemberInfo)), - CreateReadValueExpression(typeof(object), property.GetIndex(), property), - Default(typeof(object))); + CreateReadValueExpression(property.ClrType, property.GetIndex(), property), + Default(property.ClrType)); } var entityProjection = new EntityProjectionExpression(entityType, readExpressionMap); @@ -256,7 +256,7 @@ public virtual void ApplyDefaultIfEmpty() { if (_valueBufferSlots.Count != 0) { - throw new InvalidOperationException("Cannot Apply DefaultIfEmpty after ClientProjection."); + throw new InvalidOperationException("Cannot apply DefaultIfEmpty after a client-evaluated projection."); } var result = new Dictionary(); diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs index d1d8fad4e53..39e405bddba 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs @@ -6,8 +6,6 @@ using System.Linq; using System.Linq.Expressions; using System.Reflection; -using Microsoft.EntityFrameworkCore.Diagnostics; -using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Metadata.Internal; using Microsoft.EntityFrameworkCore.Query; @@ -131,23 +129,17 @@ protected override ShapedQueryExpression TranslateConcat(ShapedQueryExpression s protected override ShapedQueryExpression TranslateContains(ShapedQueryExpression source, Expression item) { var inMemoryQueryExpression = (InMemoryQueryExpression)source.QueryExpression; - var itemType = item.Type; - item = TranslateExpression(item); + item = TranslateExpression(item, preserveType: true); if (item == null) { return null; } - if (item.Type != itemType) - { - item = Expression.Convert(item, itemType); - } - inMemoryQueryExpression.ServerQueryExpression = Expression.Call( - InMemoryLinqOperatorProvider.Contains.MakeGenericMethod(itemType), + InMemoryLinqOperatorProvider.Contains.MakeGenericMethod(item.Type), Expression.Call( - InMemoryLinqOperatorProvider.Select.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type, itemType), + InMemoryLinqOperatorProvider.Select.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type, item.Type), inMemoryQueryExpression.ServerQueryExpression, Expression.Lambda( inMemoryQueryExpression.GetMappedProjection(new ProjectionMember()), inMemoryQueryExpression.CurrentParameter)), @@ -349,9 +341,7 @@ protected override ShapedQueryExpression TranslateJoin(ShapedQueryExpression out return null; } - var unifyNullabilityResult = UnifyNullability(outerKeySelector, innerKeySelector); - outerKeySelector = unifyNullabilityResult.lambda1; - innerKeySelector = unifyNullabilityResult.lambda2; + (outerKeySelector, innerKeySelector) = AlignKeySelectorTypes(outerKeySelector, innerKeySelector); var transparentIdentifierType = TransparentIdentifierFactory.Create( resultSelector.Parameters[0].Type, @@ -370,27 +360,31 @@ protected override ShapedQueryExpression TranslateJoin(ShapedQueryExpression out transparentIdentifierType); } - private (LambdaExpression lambda1, LambdaExpression lambda2) UnifyNullability(LambdaExpression lambda1, LambdaExpression lambda2) + private (LambdaExpression OuterKeySelector, LambdaExpression InnerKeySelector) + AlignKeySelectorTypes(LambdaExpression outerKeySelector, LambdaExpression innerKeySelector) { - if (lambda1.Body.Type != lambda2.Body.Type) + static bool isConvertedToNullable(Expression outer, Expression inner) + => outer.Type.IsNullableType() + && !inner.Type.IsNullableType() + && outer.Type.UnwrapNullableType() == inner.Type; + + if (outerKeySelector.Body.Type != innerKeySelector.Body.Type) { - if (TypeNullabilityChanged(lambda1.Body.Type, lambda2.Body.Type)) + if (isConvertedToNullable(outerKeySelector.Body, innerKeySelector.Body)) { - lambda2 = Expression.Lambda(Expression.Convert(lambda2.Body, lambda1.Body.Type), lambda2.Parameters); + innerKeySelector = Expression.Lambda( + Expression.Convert(innerKeySelector.Body, outerKeySelector.Body.Type), innerKeySelector.Parameters); } - else if (TypeNullabilityChanged(lambda2.Body.Type, lambda1.Body.Type)) + else if (isConvertedToNullable(innerKeySelector.Body, outerKeySelector.Body)) { - lambda1 = Expression.Lambda(Expression.Convert(lambda1.Body, lambda2.Body.Type), lambda1.Parameters); + outerKeySelector = Expression.Lambda( + Expression.Convert(outerKeySelector.Body, innerKeySelector.Body.Type), outerKeySelector.Parameters); } } - return (lambda1, lambda2); + return (outerKeySelector, innerKeySelector); } - // TODO: DRY - private bool TypeNullabilityChanged(Type maybeNullableType, Type nonNullableType) - => maybeNullableType.IsNullableType() && !nonNullableType.IsNullableType() && maybeNullableType.UnwrapNullableType() == nonNullableType; - protected override ShapedQueryExpression TranslateLastOrDefault(ShapedQueryExpression source, LambdaExpression predicate, Type returnType, bool returnDefault) { return TranslateSingleResultOperator( @@ -411,9 +405,7 @@ protected override ShapedQueryExpression TranslateLeftJoin(ShapedQueryExpression return null; } - var unifyNullabilityResult = UnifyNullability(outerKeySelector, innerKeySelector); - outerKeySelector = unifyNullabilityResult.lambda1; - innerKeySelector = unifyNullabilityResult.lambda2; + (outerKeySelector, innerKeySelector) = AlignKeySelectorTypes(outerKeySelector, innerKeySelector); var transparentIdentifierType = TransparentIdentifierFactory.Create( resultSelector.Parameters[0].Type, @@ -752,9 +744,19 @@ protected override ShapedQueryExpression TranslateWhere(ShapedQueryExpression so return source; } - private Expression TranslateExpression(Expression expression) + private Expression TranslateExpression(Expression expression, bool preserveType = false) { - return _expressionTranslator.Translate(expression); + var result = _expressionTranslator.Translate(expression); + + if (expression != null && result != null + && preserveType && expression.Type != result.Type) + { + result = expression.Type == typeof(bool) + ? Expression.Equal(result, Expression.Constant(true, result.Type)) + : (Expression)Expression.Convert(result, expression.Type); + } + + return result; } private LambdaExpression TranslateLambdaExpression( @@ -762,16 +764,7 @@ private LambdaExpression TranslateLambdaExpression( LambdaExpression lambdaExpression, bool preserveType = false) { - var lambdaBody = TranslateExpression(RemapLambdaBody(shapedQueryExpression, lambdaExpression)); - - if (lambdaBody != null && preserveType) - { - lambdaBody = lambdaBody.Type == typeof(bool?) - ? Expression.Equal( - lambdaBody, - Expression.Constant(true, typeof(bool?))) - : lambdaBody; - } + var lambdaBody = TranslateExpression(RemapLambdaBody(shapedQueryExpression, lambdaExpression), preserveType); return lambdaBody != null ? Expression.Lambda(lambdaBody, @@ -790,25 +783,18 @@ private ShapedQueryExpression TranslateScalarAggregate( { var inMemoryQueryExpression = (InMemoryQueryExpression)source.QueryExpression; - var selectorBodyType = selector?.Body.Type; - selector = selector == null || selector.Body == selector.Parameters[0] ? Expression.Lambda( inMemoryQueryExpression.GetMappedProjection(new ProjectionMember()), inMemoryQueryExpression.CurrentParameter) - : TranslateLambdaExpression(source, selector); + : TranslateLambdaExpression(source, selector, preserveType: true); if (selector == null) { return null; } - if (selectorBodyType != null && selector.Body.Type != selectorBodyType) - { - selector = Expression.Lambda(Expression.Convert(selector.Body, selectorBodyType), selector.Parameters); - } - MethodInfo getMethod() => methodName switch {