From aa0440e014f126a2218eafa4213048a0284d715f Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Mon, 5 Aug 2019 15:45:08 +0200 Subject: [PATCH] Fix nullability in entity equality Fixes #16564 --- ...ntityEqualityRewritingExpressionVisitor.cs | 57 ++++++++++++++----- .../Query/SimpleQueryInMemoryTest.cs | 24 -------- ...impleQuerySqlServerTest.ResultOperators.cs | 6 +- .../Query/SimpleQuerySqlServerTest.cs | 4 +- 4 files changed, 47 insertions(+), 44 deletions(-) diff --git a/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs b/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs index 57c01ca12bb..e9bbf9e48fc 100644 --- a/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs @@ -378,7 +378,9 @@ protected virtual Expression VisitContainsMethodCall(MethodCallExpression method var keyListType = typeof(List<>).MakeGenericType(keyProperty.ClrType); var lambda = Expression.Lambda( Expression.Call( - _parameterListValueExtractor.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType), + (keyProperty.ClrType.IsValueType + ? _parameterListValueTypeExtractor + : _parameterListRefTypeExtractor).MakeGenericMethod(entityType.ClrType, keyProperty.ClrType), QueryCompilationContext.QueryContextParameter, Expression.Constant(listParam.Name, typeof(string)), Expression.Constant(keyProperty, typeof(IProperty))), @@ -394,7 +396,7 @@ protected virtual Expression VisitContainsMethodCall(MethodCallExpression method var param = Expression.Parameter(entityType.ClrType, "v"); var keySelector = Expression.Lambda(CreatePropertyAccessExpression(param, keyProperty), param); rewrittenSource = Expression.Call( - QueryableMethodProvider.SelectMethodInfo.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType), + QueryableMethodProvider.SelectMethodInfo.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType.MakeNullable()), Unwrap(newSource), Expression.Quote(keySelector)); } @@ -407,7 +409,7 @@ protected virtual Expression VisitContainsMethodCall(MethodCallExpression method return Expression.Call( (Unwrap(newSource).Type.IsQueryableType() ? QueryableMethodProvider.ContainsMethodInfo - : _enumerableContainsMethodInfo).MakeGenericMethod(keyProperty.ClrType), + : _enumerableContainsMethodInfo).MakeGenericMethod(keyProperty.ClrType.MakeNullable()), rewrittenSource, rewrittenItem ); @@ -466,7 +468,7 @@ protected virtual Expression VisitOrderingMethodCall(MethodCallExpression method var orderingMethodInfo = GetOrderingMethodInfo(firstOrdering, isAscending); expression = Expression.Call( - orderingMethodInfo.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType), + orderingMethodInfo.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType.MakeNullable()), expression, Expression.Quote(rewrittenKeySelector) ); @@ -753,7 +755,7 @@ private Expression RewriteNullEquality( // (this is also why we can do it even over a subquery with a composite key) return Expression.MakeBinary( equality ? ExpressionType.Equal : ExpressionType.NotEqual, - CreatePropertyAccessExpression(nonNullExpression, keyProperties[0], makeNullable: true), + CreatePropertyAccessExpression(nonNullExpression, keyProperties[0]), Expression.Constant(null)); } @@ -853,12 +855,12 @@ protected virtual Expression CreateKeyAccessExpression( .Cast() .ToArray())); - private Expression CreatePropertyAccessExpression(Expression target, IProperty property, bool makeNullable = false) + private Expression CreatePropertyAccessExpression(Expression target, IProperty property) { // The target is a constant - evaluate the property immediately and return the result if (target is ConstantExpression constantExpression) { - return Expression.Constant(property.GetGetter().GetClrValue(constantExpression.Value), property.ClrType); + return Expression.Constant(property.GetGetter().GetClrValue(constantExpression.Value), property.ClrType.MakeNullable()); } // If the target is a query parameter, we can't simply add a property access over it, but must instead cause a new @@ -877,10 +879,10 @@ private Expression CreatePropertyAccessExpression(Expression target, IProperty p QueryCompilationContext.QueryContextParameter); var newParameterName = $"{RuntimeParameterPrefix}{baseParameterExpression.Name.Substring(CompiledQueryCache.CompiledQueryParameterPrefix.Length)}_{property.Name}"; - return _queryCompilationContext.RegisterRuntimeParameter(newParameterName, lambda, property.ClrType); + return _queryCompilationContext.RegisterRuntimeParameter(newParameterName, lambda, property.ClrType.MakeNullable()); } - return target.CreateEFPropertyExpression(property, makeNullable); + return target.CreateEFPropertyExpression(property, true); } private static object ParameterValueExtractor(QueryContext context, string baseParameterName, IProperty property) @@ -898,7 +900,8 @@ private static readonly MethodInfo _parameterValueExtractor /// Extracts the list parameter with name from and returns a /// projection to its elements' values. /// - private static object ParameterListValueExtractor(QueryContext context, string baseParameterName, IProperty property) + private static List ParameterListRefTypeExtractor(QueryContext context, string baseParameterName, IProperty property) + where TProperty : class { Debug.Assert(property.ClrType == typeof(TProperty)); @@ -912,10 +915,34 @@ private static object ParameterListValueExtractor(QueryConte return baseListParameter.Select(e => (TProperty)getter.GetClrValue(e)).ToList(); } - private static readonly MethodInfo _parameterListValueExtractor + private static readonly MethodInfo _parameterListRefTypeExtractor = typeof(EntityEqualityRewritingExpressionVisitor) .GetTypeInfo() - .GetDeclaredMethod(nameof(ParameterListValueExtractor)); + .GetDeclaredMethod(nameof(ParameterListRefTypeExtractor)); + + /// + /// Extracts the list parameter with name from and returns a + /// projection to its elements' values. + /// + private static List ParameterListValueTypeExtractor(QueryContext context, string baseParameterName, IProperty property) + where TProperty : struct + { + Debug.Assert(property.ClrType == typeof(TProperty)); + + var baseListParameter = context.ParameterValues[baseParameterName] as IEnumerable; + if (baseListParameter == null) + { + return null; + } + + var getter = property.GetGetter(); + return baseListParameter.Select(e => (TProperty?)getter.GetClrValue(e)).ToList(); + } + + private static readonly MethodInfo _parameterListValueTypeExtractor + = typeof(EntityEqualityRewritingExpressionVisitor) + .GetTypeInfo() + .GetDeclaredMethod(nameof(ParameterListValueTypeExtractor)); protected static Expression UnwrapLastNavigation(Expression expression) => (expression as MemberExpression)?.Expression @@ -1063,16 +1090,16 @@ public virtual void Print(ExpressionPrinter expressionPrinter) if (IsEntityType) { - expressionPrinter.StringBuilder.Append($".EntityType({EntityType})"); + expressionPrinter.Append($".EntityType({EntityType})"); } else if (IsDtoType) { - expressionPrinter.StringBuilder.Append(".DTO"); + expressionPrinter.Append(".DTO"); } if (SubqueryTraversed) { - expressionPrinter.StringBuilder.Append(".SubqueryTraversed"); + expressionPrinter.Append(".SubqueryTraversed"); } } diff --git a/test/EFCore.InMemory.FunctionalTests/Query/SimpleQueryInMemoryTest.cs b/test/EFCore.InMemory.FunctionalTests/Query/SimpleQueryInMemoryTest.cs index 6fcc725f12f..738e32be31d 100644 --- a/test/EFCore.InMemory.FunctionalTests/Query/SimpleQueryInMemoryTest.cs +++ b/test/EFCore.InMemory.FunctionalTests/Query/SimpleQueryInMemoryTest.cs @@ -229,12 +229,6 @@ public override Task Union_over_different_projection_types(bool isAsync, string #endregion - [ConditionalFact(Skip = "Issue#16564")] - public override void Contains_over_entityType_with_null_should_rewrite_to_identity_equality() - { - base.Contains_over_entityType_with_null_should_rewrite_to_identity_equality(); - } - [ConditionalTheory(Skip = "Issue#15711")] public override Task Include_with_orderby_skip_preserves_ordering(bool isAsync) { @@ -265,24 +259,6 @@ public override void Select_nested_collection_multi_level6() base.Select_nested_collection_multi_level6(); } - [ConditionalTheory(Skip = "Issue#16564")] - public override Task Where_subquery_FirstOrDefault_compared_to_entity(bool isAsync) - { - return base.Where_subquery_FirstOrDefault_compared_to_entity(isAsync); - } - - [ConditionalTheory(Skip = "Issue#16564")] - public override Task Where_query_composition_entity_equality_one_element_FirstOrDefault(bool isAsync) - { - return base.Where_query_composition_entity_equality_one_element_FirstOrDefault(isAsync); - } - - [ConditionalTheory(Skip = "Issue#16564")] - public override Task Where_query_composition_entity_equality_no_elements_FirstOrDefault(bool isAsync) - { - return base.Where_query_composition_entity_equality_no_elements_FirstOrDefault(isAsync); - } - [ConditionalTheory(Skip = "Issue#16575")] public override Task Project_single_element_from_collection_with_OrderBy_Distinct_and_FirstOrDefault_followed_by_projecting_length(bool isAsync) { diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.ResultOperators.cs b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.ResultOperators.cs index 98d01801aaa..3840e074a8c 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.ResultOperators.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.ResultOperators.cs @@ -1176,7 +1176,7 @@ public override void Contains_over_entityType_should_rewrite_to_identity_equalit FROM [Orders] AS [o] WHERE [o].[OrderID] = 10248", // - @"@__entity_equality_p_0_OrderID='10248' + @"@__entity_equality_p_0_OrderID='10248' (Nullable = true) SELECT CASE WHEN @__entity_equality_p_0_OrderID IN ( @@ -1194,7 +1194,7 @@ public override async Task List_Contains_over_entityType_should_rewrite_to_ident await base.List_Contains_over_entityType_should_rewrite_to_identity_equality(isAsync); AssertSql( - @"@__entity_equality_someOrder_0_OrderID='10248' + @"@__entity_equality_someOrder_0_OrderID='10248' (Nullable = true) SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] FROM [Customers] AS [c] @@ -1230,7 +1230,7 @@ public override void Contains_over_entityType_with_null_should_rewrite_to_identi base.Contains_over_entityType_with_null_should_rewrite_to_identity_equality(); AssertSql( - @"@__entity_equality_p_0_OrderID='' (Nullable = false) (DbType = Int32) + @"@__entity_equality_p_0_OrderID='' (DbType = Int32) SELECT CASE WHEN @__entity_equality_p_0_OrderID IN ( diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs index 3febfbc98d3..a41408aa061 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs @@ -187,8 +187,8 @@ public override async Task Entity_equality_local_composite_key(bool isAsync) await base.Entity_equality_local_composite_key(isAsync); AssertSql( - @"@__entity_equality_local_0_OrderID='10248' -@__entity_equality_local_0_ProductID='11' + @"@__entity_equality_local_0_OrderID='10248' (Nullable = true) +@__entity_equality_local_0_ProductID='11' (Nullable = true) SELECT [o].[OrderID], [o].[ProductID], [o].[Discount], [o].[Quantity], [o].[UnitPrice] FROM [Order Details] AS [o]