From 6dbf690b2678711a8f0ffce441e1ad857f35c875 Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Fri, 24 May 2019 12:57:50 +0200 Subject: [PATCH] Reimplement entity equality New implementation of entity equality rewriter, which follows exact entity (and anonymous) from roots to the comparison expression, for more precise rewriting. Closes #15588 --- src/EFCore/Diagnostics/CoreEventId.cs | 2 +- .../Diagnostics/CoreLoggerExtensions.cs | 14 +- .../Diagnostics/NavigationPathEventData.cs | 38 - .../Internal/ExpressionExtensions.cs | 29 - src/EFCore/Properties/CoreStrings.Designer.cs | 8 + src/EFCore/Properties/CoreStrings.resx | 5 +- .../Internal/NullConditionalExpression.cs | 13 +- ...ntityEqualityRewritingExpressionVisitor.cs | 739 ++++++++++++++++++ .../QueryOptimizingExpressionVisitor.cs | 1 + ...yableMethodTranslatingExpressionVisitor.cs | 2 +- .../Query/FromSqlQueryTestBase.cs | 17 + .../Query/GearsOfWarQueryTestBase.cs | 16 +- .../Query/IncludeTestBase.cs | 2 +- .../SimpleQueryTestBase.ResultOperators.cs | 2 +- .../Query/SimpleQueryTestBase.Where.cs | 2 +- .../Query/SimpleQueryTestBase.cs | 140 +++- .../Query/DbFunctionsSqlServerTest.cs | 2 +- .../Query/SimpleQuerySqlServerTest.cs | 64 ++ 18 files changed, 976 insertions(+), 120 deletions(-) delete mode 100644 src/EFCore/Diagnostics/NavigationPathEventData.cs create mode 100644 src/EFCore/Query/Pipeline/EntityEqualityRewritingExpressionVisitor.cs diff --git a/src/EFCore/Diagnostics/CoreEventId.cs b/src/EFCore/Diagnostics/CoreEventId.cs index 86efb20facf..e67fbb5196b 100644 --- a/src/EFCore/Diagnostics/CoreEventId.cs +++ b/src/EFCore/Diagnostics/CoreEventId.cs @@ -236,7 +236,7 @@ private enum Id /// This event is in the category. /// /// - /// This event uses the payload when used with a . + /// This event uses the payload when used with a . /// /// public static readonly EventId PossibleUnintendedCollectionNavigationNullComparisonWarning diff --git a/src/EFCore/Diagnostics/CoreLoggerExtensions.cs b/src/EFCore/Diagnostics/CoreLoggerExtensions.cs index 017cf92ca5f..796c8951569 100644 --- a/src/EFCore/Diagnostics/CoreLoggerExtensions.cs +++ b/src/EFCore/Diagnostics/CoreLoggerExtensions.cs @@ -503,10 +503,10 @@ public static void SensitiveDataLoggingEnabledWarning( /// Logs for the event. /// /// The diagnostics logger to use. - /// The navigation properties being used. + /// The navigation being used. public static void PossibleUnintendedCollectionNavigationNullComparisonWarning( [NotNull] this IDiagnosticsLogger diagnostics, - [NotNull] IReadOnlyList navigationPath) + [NotNull] INavigation navigation) { var definition = CoreResources.LogPossibleUnintendedCollectionNavigationNullComparison(diagnostics); @@ -516,25 +516,25 @@ public static void PossibleUnintendedCollectionNavigationNullComparisonWarning( definition.Log( diagnostics, warningBehavior, - string.Join(".", navigationPath.Select(p => p.Name))); + $"{navigation.DeclaringEntityType.Name}.{navigation.GetTargetType().Name}"); } if (diagnostics.DiagnosticSource.IsEnabled(definition.EventId.Name)) { diagnostics.DiagnosticSource.Write( definition.EventId.Name, - new NavigationPathEventData( + new NavigationEventData( definition, PossibleUnintendedCollectionNavigationNullComparisonWarning, - navigationPath)); + navigation)); } } private static string PossibleUnintendedCollectionNavigationNullComparisonWarning(EventDefinitionBase definition, EventData payload) { var d = (EventDefinition)definition; - var p = (NavigationPathEventData)payload; - return d.GenerateMessage(string.Join(".", p.NavigationPath.Select(pb => pb.Name))); + var p = (NavigationEventData)payload; + return d.GenerateMessage($"{p.Navigation.DeclaringEntityType.Name}.{p.Navigation.GetTargetType().Name}"); } /// diff --git a/src/EFCore/Diagnostics/NavigationPathEventData.cs b/src/EFCore/Diagnostics/NavigationPathEventData.cs deleted file mode 100644 index ef0bdf2fc26..00000000000 --- a/src/EFCore/Diagnostics/NavigationPathEventData.cs +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Collections.Generic; -using System.Diagnostics; -using JetBrains.Annotations; -using Microsoft.EntityFrameworkCore.Metadata; - -namespace Microsoft.EntityFrameworkCore.Diagnostics -{ - /// - /// A event payload class for events that have - /// a navigation property. - /// - public class NavigationPathEventData : EventData - { - /// - /// Constructs the event payload. - /// - /// The event definition. - /// A delegate that generates a log message for this event. - /// The navigation property. - public NavigationPathEventData( - [NotNull] EventDefinitionBase eventDefinition, - [NotNull] Func messageGenerator, - [NotNull] IReadOnlyCollection navigationPath) - : base(eventDefinition, messageGenerator) - { - NavigationPath = navigationPath; - } - - /// - /// The navigation property. - /// - public virtual IReadOnlyCollection NavigationPath { get; } - } -} diff --git a/src/EFCore/Extensions/Internal/ExpressionExtensions.cs b/src/EFCore/Extensions/Internal/ExpressionExtensions.cs index d8fa43e9fe2..59e9226ab90 100644 --- a/src/EFCore/Extensions/Internal/ExpressionExtensions.cs +++ b/src/EFCore/Extensions/Internal/ExpressionExtensions.cs @@ -373,35 +373,6 @@ public static bool IsNullPropagationCandidate( return true; } - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public static Expression CreateKeyAccessExpression( - [NotNull] this Expression target, - [NotNull] IReadOnlyList properties) - { - Check.NotNull(target, nameof(target)); - Check.NotNull(properties, nameof(properties)); - - return properties.Count == 1 - ? target.CreateEFPropertyExpression(properties[0]) - : Expression.New( - AnonymousObject.AnonymousObjectCtor, - Expression.NewArrayInit( - typeof(object), - properties - .Select( - p => - Expression.Convert( - target.CreateEFPropertyExpression(p), - typeof(object))) - .Cast() - .ToArray())); - } - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in diff --git a/src/EFCore/Properties/CoreStrings.Designer.cs b/src/EFCore/Properties/CoreStrings.Designer.cs index 555f6a51b34..db904865064 100644 --- a/src/EFCore/Properties/CoreStrings.Designer.cs +++ b/src/EFCore/Properties/CoreStrings.Designer.cs @@ -2100,6 +2100,14 @@ public static string NoNavigation([CanBeNull] object entityType, [CanBeNull] obj GetString("NoNavigation", nameof(entityType), nameof(foreignKey)), entityType, foreignKey); + /// + /// This query would cause multiple evaluation of a subquery because entity '{entityType}' has a composite key. Rewrite your query avoiding the subquery. + /// + public static string SubqueryWithCompositeKeyNotSupported([CanBeNull] object entityType) + => string.Format( + GetString("SubqueryWithCompositeKeyNotSupported", nameof(entityType)), + entityType); + private static string GetString(string name, params string[] formatterNames) { var value = _resourceManager.GetString(name); diff --git a/src/EFCore/Properties/CoreStrings.resx b/src/EFCore/Properties/CoreStrings.resx index fc76207d650..cf27e4c9445 100644 --- a/src/EFCore/Properties/CoreStrings.resx +++ b/src/EFCore/Properties/CoreStrings.resx @@ -1168,4 +1168,7 @@ There is no navigation on entity type '{entityType}' associated with the foreign key {foreignKey}. - \ No newline at end of file + + This query would cause multiple evaluation of a subquery because entity '{entityType}' has a composite key. Rewrite your query avoiding the subquery. + + diff --git a/src/EFCore/Query/Expressions/Internal/NullConditionalExpression.cs b/src/EFCore/Query/Expressions/Internal/NullConditionalExpression.cs index 48706c2c116..6f6579f8194 100644 --- a/src/EFCore/Query/Expressions/Internal/NullConditionalExpression.cs +++ b/src/EFCore/Query/Expressions/Internal/NullConditionalExpression.cs @@ -110,16 +110,13 @@ var operation /// /// An instance of . protected override Expression VisitChildren(ExpressionVisitor visitor) - { - var newCaller = visitor.Visit(Caller); - var newAccessOperation = visitor.Visit(AccessOperation); + => Update(visitor.Visit(Caller), visitor.Visit(AccessOperation)); - return newCaller != Caller - || newAccessOperation != AccessOperation - && !(ExpressionEqualityComparer.Instance.Equals((newAccessOperation as NullConditionalExpression)?.AccessOperation, AccessOperation)) + public virtual Expression Update(Expression newCaller, Expression newAccessOperation) + => newCaller != Caller || newAccessOperation != AccessOperation + && !ExpressionEqualityComparer.Instance.Equals((newAccessOperation as NullConditionalExpression)?.AccessOperation, AccessOperation) ? new NullConditionalExpression(newCaller, newAccessOperation) - : (this); - } + : this; /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to diff --git a/src/EFCore/Query/Pipeline/EntityEqualityRewritingExpressionVisitor.cs b/src/EFCore/Query/Pipeline/EntityEqualityRewritingExpressionVisitor.cs new file mode 100644 index 00000000000..4620a386d92 --- /dev/null +++ b/src/EFCore/Query/Pipeline/EntityEqualityRewritingExpressionVisitor.cs @@ -0,0 +1,739 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore.Diagnostics; +using Microsoft.EntityFrameworkCore.Extensions.Internal; +using Microsoft.EntityFrameworkCore.Internal; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Query.Expressions.Internal; +using Microsoft.EntityFrameworkCore.Query.Internal; +using Microsoft.EntityFrameworkCore.Query.NavigationExpansion; + +namespace Microsoft.EntityFrameworkCore.Query.Pipeline +{ + /// + /// Rewrites comparisons of entities (as opposed to comparisons of their properties) into comparison of their keys. + /// + /// + /// For example, an expression such as cs.Where(c => c == something) would be rewritten to cs.Where(c => c.Id == something.Id). + /// + public class EntityEqualityRewritingExpressionVisitor : ExpressionVisitor + { + protected IDiagnosticsLogger Logger { get; } + protected IModel Model { get; } + + private static readonly MethodInfo _objectEqualsMethodInfo + = typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) }); + + public EntityEqualityRewritingExpressionVisitor(QueryCompilationContext queryCompilationContext) + { + Model = queryCompilationContext.Model; + Logger = queryCompilationContext.Logger; + } + + public Expression Rewrite(Expression expression) => Unwrap(Visit(expression)); + + protected override Expression VisitConstant(ConstantExpression constantExpression) + => constantExpression.IsEntityQueryable() + ? new EntityReferenceExpression(constantExpression, Model.FindEntityType(((IQueryable)constantExpression.Value).ElementType)) + : (Expression)constantExpression; + + protected override Expression VisitNew(NewExpression newExpression) + { + var visitedArgs = Visit(newExpression.Arguments); + var visitedExpression = newExpression.Update(visitedArgs.Select(Unwrap)); + + return (newExpression.Members?.Count ?? 0) == 0 + ? (Expression)visitedExpression + : new EntityReferenceExpression(visitedExpression, visitedExpression.Members + .Select((m, i) => (Member: m, Index: i)) + .ToDictionary( + mi => mi.Member.Name, + mi => visitedArgs[mi.Index])); + } + + protected override Expression VisitMember(MemberExpression memberExpression) + { + var visitedExpression = base.Visit(memberExpression.Expression); + var visitedMemberExpression = memberExpression.Update(Unwrap(visitedExpression)); + return visitedExpression is EntityReferenceExpression entityWrapper + ? entityWrapper.TraverseProperty(memberExpression.Member.Name, visitedMemberExpression) + : visitedMemberExpression; + } + + protected override Expression VisitBinary(BinaryExpression binaryExpression) + { + var (newLeft, newRight) = (Visit(binaryExpression.Left), Visit(binaryExpression.Right)); + if (binaryExpression.NodeType == ExpressionType.Equal || binaryExpression.NodeType == ExpressionType.NotEqual) + { + if (RewriteEquality(binaryExpression.NodeType == ExpressionType.Equal, newLeft, newRight) is Expression result) + { + return result; + } + } + + return binaryExpression.Update(Unwrap(newLeft), binaryExpression.Conversion, Unwrap(newRight)); + } + + protected override Expression VisitUnary(UnaryExpression unaryExpression) + { + // This is needed for Convert but is generalized + var newOperand = Visit(unaryExpression.Operand); + var newUnary = unaryExpression.Update(Unwrap(newOperand)); + return newOperand is EntityReferenceExpression entityWrapper + ? entityWrapper.Update(newUnary) + : (Expression)newUnary; + } + + protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExpression) + { + // This is for "x is y" + var visitedExpression = Visit(typeBinaryExpression.Expression); + var visitedTypeBinary= typeBinaryExpression.Update(Unwrap(visitedExpression)); + return visitedExpression is EntityReferenceExpression entityWrapper + ? entityWrapper.Update(visitedTypeBinary) + : (Expression)visitedTypeBinary; + } + + protected override Expression VisitConditional(ConditionalExpression conditionalExpression) + { + var newTest = Visit(conditionalExpression.Test); + var newIfTrue = Visit(conditionalExpression.IfTrue); + var newIfFalse = Visit(conditionalExpression.IfFalse); + + var newConditional = conditionalExpression.Update(newTest, Unwrap(newIfTrue), Unwrap(newIfFalse)); + + // TODO: the true and false sides may refer different entity types which happen to have the same + // CLR type (e.g. shared entities) + var wrapper = newIfTrue as EntityReferenceExpression ?? newIfFalse as EntityReferenceExpression; + + return wrapper == null ? (Expression)newConditional : wrapper.Update(newConditional); + } + + protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) + { + var arguments = methodCallExpression.Arguments; + Expression newSource; + + // Check if this is this Equals() + if (methodCallExpression.Method.Name == nameof(object.Equals) + && methodCallExpression.Object != null + && methodCallExpression.Arguments.Count == 1) + { + var (newLeft, newRight) = (Visit(methodCallExpression.Object), Visit(arguments[0])); + return RewriteEquality(true, newLeft, newRight) + ?? methodCallExpression.Update(Unwrap(newLeft), new[] { Unwrap(newRight) }); + } + + if (methodCallExpression.Method.Equals(_objectEqualsMethodInfo)) + { + var (newLeft, newRight) = (Visit(arguments[0]), Visit(arguments[1])); + return RewriteEquality(true, newLeft, newRight) + ?? methodCallExpression.Update(null, new[] { Unwrap(newLeft), Unwrap(newRight) }); + } + + // Navigation via EF.Property() or via an indexer property + if (methodCallExpression.TryGetEFPropertyArguments(out _, out var propertyName) + || methodCallExpression.TryGetEFIndexerArguments(out _, out propertyName)) + { + newSource = Visit(arguments[0]); + var newMethodCall = methodCallExpression.Update(null, new[] { Unwrap(newSource), arguments[1] }); + return newSource is EntityReferenceExpression entityWrapper + ? entityWrapper.TraverseProperty(propertyName, newMethodCall) + : newMethodCall; + } + + if (methodCallExpression.Method.DeclaringType == typeof(Queryable) + || methodCallExpression.Method.DeclaringType == typeof(Enumerable) + || methodCallExpression.Method.DeclaringType == typeof(EntityQueryableExtensions)) + { + switch (methodCallExpression.Method.Name) + { + // The following are projecting methods, which flow the entity type from *within* the lambda outside. + // These are handled by dedicated methods + case nameof(Queryable.Select): + case nameof(Queryable.SelectMany): + return VisitSelectMethodCall(methodCallExpression); + + case nameof(Queryable.GroupJoin): + case nameof(Queryable.Join): + case nameof(EntityQueryableExtensions.LeftJoin): + return VisitJoinMethodCall(methodCallExpression); + + case nameof(Queryable.GroupBy): // TODO: Implement + break; + } + } + + // TODO: Can add an extension point that can be overridden by subclassing visitors to recognize additional methods and flow through the entity type. + // Do this here, since below we visit the arguments (avoid double visitation) + + if (arguments.Count == 0) + { + return methodCallExpression.Update( + Unwrap(Visit(methodCallExpression.Object)), Array.Empty()); + } + + // Methods with a typed first argument (source), and with no lambda arguments or a single lambda + // argument that has one parameter are rewritten automatically (e.g. Where(), FromSql(), Average() + var newArguments = new Expression[arguments.Count]; + var lambdaArgs = arguments.Select(GetLambdaOrNull).Where(l => l != null).ToArray(); + newSource = Visit(arguments[0]); + newArguments[0] = Unwrap(newSource); + if (methodCallExpression.Object == null + && newSource is EntityReferenceExpression newSourceWrapper + && (lambdaArgs.Length == 0 + || lambdaArgs.Length == 1 && lambdaArgs[0].Parameters.Count == 1)) + { + for (var i = 1; i < arguments.Count; i++) + { + // Visit all arguments, rewriting the single lambda to replace its parameter expression + newArguments[i] = GetLambdaOrNull(arguments[i]) is LambdaExpression lambda + ? Unwrap(RewriteAndVisitLambda(lambda, newSourceWrapper)) + : Unwrap(Visit(arguments[i])); + } + + var sourceParamType = methodCallExpression.Method.GetParameters()[0].ParameterType; + var sourceElementType = sourceParamType.TryGetSequenceType(); + if (sourceElementType != null + || sourceParamType == typeof(IQueryable)) // OfType + { + // If the method returns the element same type as the source, flow the type information + // (e.g. Where, OrderBy) + if (methodCallExpression.Method.ReturnType.TryGetSequenceType() is Type returnElementType + && (returnElementType == sourceElementType || sourceElementType == null)) + { + return newSourceWrapper.Update( + methodCallExpression.Update(null, newArguments)); + } + + // If the source type is an IQueryable over the return type, this is a cardinality-reducing method (e.g. First). + // These don't flow the last navigation. In addition, these will be translated into a subquery, and we should not + // perform entity equality rewriting if the entity type has a composite key. + if (methodCallExpression.Method.ReturnType == sourceElementType) + { + return new EntityReferenceExpression( + methodCallExpression.Update(null, newArguments), + newSourceWrapper.EntityType, + lastNavigation: null, + newSourceWrapper.AnonymousType, + subqueryTraversed: true); + } + } + + // Method does not flow entity type (e.g. Average) + return methodCallExpression.Update(null, newArguments); + } + + // Unknown method - still need to visit all arguments + for (var i = 1; i < arguments.Count; i++) + { + newArguments[i] = Unwrap(Visit(arguments[i])); + } + + return methodCallExpression.Update(Unwrap(Visit(methodCallExpression.Object)), newArguments); + } + + protected virtual Expression VisitSelectMethodCall(MethodCallExpression methodCallExpression) + { + var arguments = methodCallExpression.Arguments; + var newSource = Visit(arguments[0]); + + if (!(newSource is EntityReferenceExpression sourceWrapper)) + { + return arguments.Count == 2 + ? methodCallExpression.Update(null, new[] { newSource, Unwrap(Visit(arguments[1])) }) + : arguments.Count == 3 + ? methodCallExpression.Update(null, new[] { newSource, Unwrap(Visit(arguments[1])), Unwrap(Visit(arguments[2])) }) + : throw new NotSupportedException(); + } + + MethodCallExpression newMethodCall; + + if (arguments.Count == 2) + { + var selector = arguments[1].UnwrapQuote(); + var newSelector = RewriteAndVisitLambda(selector, sourceWrapper); + + newMethodCall = methodCallExpression.Update(null, new[] { Unwrap(newSource), Unwrap(newSelector) }); + return newSelector.Body is EntityReferenceExpression entityWrapper + ? entityWrapper.Update(newMethodCall) + : (Expression)newMethodCall; + } + + if (arguments.Count == 3) + { + var collectionSelector = arguments[1].UnwrapQuote(); + var newCollectionSelector = RewriteAndVisitLambda(collectionSelector, sourceWrapper); + + var resultSelector = arguments[2].UnwrapQuote(); + var newResultSelector = newCollectionSelector.Body is EntityReferenceExpression newCollectionSelectorWrapper + ? RewriteAndVisitLambda(resultSelector, sourceWrapper, newCollectionSelectorWrapper) + : (LambdaExpression)Visit(resultSelector); + + newMethodCall = methodCallExpression.Update(null, new[] { Unwrap(newSource), Unwrap(newCollectionSelector), Unwrap(newResultSelector) }); + return newResultSelector.Body is EntityReferenceExpression entityWrapper + ? entityWrapper.Update(newMethodCall) + : (Expression)newMethodCall; + } + + throw new NotSupportedException(); + } + + protected virtual Expression VisitJoinMethodCall(MethodCallExpression methodCallExpression) + { + var arguments = methodCallExpression.Arguments; + + if (arguments.Count != 5) + { + return base.VisitMethodCall(methodCallExpression); + } + + var newOuter = Visit(arguments[0]); + var newInner = Visit(arguments[1]); + var outerKeySelector = arguments[2].UnwrapQuote(); + var innerKeySelector = arguments[3].UnwrapQuote(); + var resultSelector = arguments[4].UnwrapQuote(); + + if (!(newOuter is EntityReferenceExpression outerWrapper && newInner is EntityReferenceExpression innerWrapper)) + { + return methodCallExpression.Update(null, new[] + { + Unwrap(newOuter), Unwrap(newInner), Unwrap(Visit(outerKeySelector)), Unwrap(Visit(innerKeySelector)), Unwrap(Visit(resultSelector)) + }); + } + + var newOuterKeySelector = RewriteAndVisitLambda(outerKeySelector, outerWrapper); + var newInnerKeySelector = RewriteAndVisitLambda(innerKeySelector, innerWrapper); + var newResultSelector = RewriteAndVisitLambda(resultSelector, outerWrapper, innerWrapper); + + MethodCallExpression newMethodCall; + + // If both outer and inner key selectors project to the same entity type, that's an entity equality + // we need to rewrite. + if (newOuterKeySelector.Body is EntityReferenceExpression outerKeySelectorWrapper + && newInnerKeySelector.Body is EntityReferenceExpression innerKeySelectorWrapper + && outerKeySelectorWrapper.IsEntityType && innerKeySelectorWrapper.IsEntityType + && outerKeySelectorWrapper.EntityType.RootType() == innerKeySelectorWrapper.EntityType.RootType()) + { + var entityType = outerKeySelectorWrapper.EntityType; + var keyProperties = entityType.FindPrimaryKey().Properties; + + if (keyProperties.Count > 1 + && (outerKeySelectorWrapper.SubqueryTraversed || innerKeySelectorWrapper.SubqueryTraversed)) + { + // One side of the comparison is the result of a subquery, and we have a composite key. + // Rewriting this would mean evaluating the subquery more than once, so we don't do it. + throw new NotSupportedException(CoreStrings.SubqueryWithCompositeKeyNotSupported(entityType.DisplayName())); + } + + // Rewrite the lambda bodies, adding the key access on top of whatever is there, and then + // produce a new MethodInfo and MethodCallExpression + var origGenericArguments = methodCallExpression.Method.GetGenericArguments(); + + var outerKeyAccessExpression = CreateKeyAccessExpression(Unwrap(outerKeySelectorWrapper), keyProperties); + var outerKeySelectorType = typeof(Func<,>).MakeGenericType(origGenericArguments[0], outerKeyAccessExpression.Type); + newOuterKeySelector = Expression.Lambda( + outerKeySelectorType, + outerKeyAccessExpression, + newOuterKeySelector.TailCall, + newOuterKeySelector.Parameters); + + var innerKeyAccessExpression = CreateKeyAccessExpression(Unwrap(innerKeySelectorWrapper), keyProperties); + var innerKeySelectorType = typeof(Func<,>).MakeGenericType(origGenericArguments[1], innerKeyAccessExpression.Type); + newInnerKeySelector = Expression.Lambda( + innerKeySelectorType, + innerKeyAccessExpression, + newInnerKeySelector.TailCall, + newInnerKeySelector.Parameters); + + var newMethod = methodCallExpression.Method.GetGenericMethodDefinition().MakeGenericMethod( + origGenericArguments[0], origGenericArguments[1], outerKeyAccessExpression.Type, origGenericArguments[3]); + + newMethodCall = Expression.Call( + newMethod, + Unwrap(newOuter), Unwrap(newInner), + newOuterKeySelector, newInnerKeySelector, + Unwrap(newResultSelector)); + } + else + { + newMethodCall = methodCallExpression.Update(null, new[] + { + Unwrap(newOuter), Unwrap(newInner), Unwrap(newOuterKeySelector), Unwrap(newInnerKeySelector), Unwrap(newResultSelector) + }); + } + + return newResultSelector.Body is EntityReferenceExpression wrapper + ? wrapper.Update(newMethodCall) + : (Expression)newMethodCall; + } + + /// + /// Replaces the lambda's single parameter with a type wrapper based on the given source, and then visits + /// the lambda's body. + /// + protected LambdaExpression RewriteAndVisitLambda(LambdaExpression lambda, EntityReferenceExpression source) + => Expression.Lambda( + lambda.Type, + Visit(ReplacingExpressionVisitor.Replace( + lambda.Parameters.Single(), + source.Update(lambda.Parameters.Single()), + lambda.Body)), + lambda.TailCall, + lambda.Parameters); + + /// + /// Replaces the lambda's two parameters with type wrappers based on the given sources, and then visits + /// the lambda's body. + /// + protected LambdaExpression RewriteAndVisitLambda(LambdaExpression lambda, + EntityReferenceExpression source1, + EntityReferenceExpression source2) + => Expression.Lambda( + lambda.Type, + Visit(new ReplacingExpressionVisitor( + new Dictionary + { + { lambda.Parameters[0], source1.Update(lambda.Parameters[0]) }, + { lambda.Parameters[1], source2.Update(lambda.Parameters[1]) } + }).Visit(lambda.Body)), + lambda.TailCall, + lambda.Parameters); + + /// + /// Receives already-visited left and right operands of an equality expression and applies entity equality rewriting to them, + /// if possible. + /// + /// The rewritten entity equality expression, or null if rewriting could not occur for some reason. + protected virtual Expression RewriteEquality(bool equality, Expression left, Expression right) + { + // TODO: Consider throwing if a child has no flowed entity type, but has a Type that corresponds to an entity type on the model. + // TODO: This would indicate an issue in our flowing logic, and would help the user (and us) understand what's going on. + + var leftTypeWrapper = left as EntityReferenceExpression; + var rightTypeWrapper = right as EntityReferenceExpression; + + // If one of the sides is an anonymous object, or both sides are unknown, abort + if (leftTypeWrapper == null && rightTypeWrapper == null + || leftTypeWrapper?.IsAnonymousType == true + || rightTypeWrapper?.IsAnonymousType == true) + { + return null; + } + + // Handle null constants + if (left.IsNullConstantExpression()) + { + if (right.IsNullConstantExpression()) + { + return equality ? Expression.Constant(true) : Expression.Constant(false); + } + + return rightTypeWrapper?.IsEntityType == true + ? RewriteNullEquality(equality, rightTypeWrapper.EntityType, rightTypeWrapper.Underlying, rightTypeWrapper.LastNavigation) + : null; + } + + if (right.IsNullConstantExpression()) + { + return leftTypeWrapper?.IsEntityType == true + ? RewriteNullEquality(equality, leftTypeWrapper.EntityType, leftTypeWrapper.Underlying, leftTypeWrapper.LastNavigation) + : null; + } + + if (leftTypeWrapper != null + && rightTypeWrapper != null + && leftTypeWrapper.EntityType.RootType() != rightTypeWrapper.EntityType.RootType()) + { + return Expression.Constant(!equality); + } + + // One side of the comparison may have an unknown entity type (closure parameter, inline instantiation) + var entityType = (leftTypeWrapper ?? rightTypeWrapper).EntityType; + + return RewriteEntityEquality( + equality, entityType, + Unwrap(left), leftTypeWrapper?.LastNavigation, + Unwrap(right), rightTypeWrapper?.LastNavigation, + leftTypeWrapper?.SubqueryTraversed == true || rightTypeWrapper?.SubqueryTraversed == true); + } + + private Expression RewriteNullEquality( + bool equality, + [NotNull] IEntityType entityType, + [NotNull] Expression nonNullExpression, + [CanBeNull] INavigation lastNavigation) + { + if (lastNavigation?.IsCollection() == true) + { + // collection navigation is only null if its parent entity is null (null propagation thru navigation) + // it is probable that user wanted to see if the collection is (not) empty + // log warning suggesting to use Any() instead. + Logger.PossibleUnintendedCollectionNavigationNullComparisonWarning(lastNavigation); + return RewriteNullEquality(equality, lastNavigation.DeclaringEntityType, UnwrapLastNavigation(nonNullExpression), null); + } + + var keyProperties = entityType.FindPrimaryKey().Properties; + + // TODO: bring back foreign key comparison optimization (#15826) + + // When comparing an entity to null, it's sufficient to simply compare its first primary key column to null. + // (this is also why we can do it even over a subquery with a composite key) + return Expression.MakeBinary( + equality ? ExpressionType.Equal : ExpressionType.NotEqual, + nonNullExpression.CreateEFPropertyExpression(keyProperties[0]), + Expression.Constant(null)); + } + + private Expression RewriteEntityEquality( + bool equality, + [NotNull] IEntityType entityType, + [NotNull] Expression left, [CanBeNull] INavigation leftNavigation, + [NotNull] Expression right, [CanBeNull] INavigation rightNavigation, + bool subqueryTraversed) + { + if (leftNavigation?.IsCollection() == true || rightNavigation?.IsCollection() == true) + { + if (leftNavigation?.Equals(rightNavigation) == true) + { + // Log a warning that comparing 2 collections causes reference comparison + Logger.PossibleUnintendedReferenceComparisonWarning(left, right); + return RewriteEntityEquality( + equality, leftNavigation.DeclaringEntityType, + UnwrapLastNavigation(left), null, + UnwrapLastNavigation(right), null, + subqueryTraversed); + } + + return Expression.Constant(!equality); + } + + var keyProperties = entityType.FindPrimaryKey().Properties; + + if (subqueryTraversed && keyProperties.Count > 1) + { + // One side of the comparison is the result of a subquery, and we have a composite key. + // Rewriting this would mean evaluating the subquery more than once, so we don't do it. + throw new NotSupportedException(CoreStrings.SubqueryWithCompositeKeyNotSupported(entityType.DisplayName())); + } + + return Expression.MakeBinary( + equality ? ExpressionType.Equal : ExpressionType.NotEqual, + CreateKeyAccessExpression(Unwrap(left), keyProperties), + CreateKeyAccessExpression(Unwrap(right), keyProperties)); + } + + protected override Expression VisitExtension(Expression expression) + { + switch (expression) + { + case EntityReferenceExpression _: + // If the expression is an EntityReferenceExpression, simply returns it as all rewriting has already occurred. + // This is necessary when traversing wrapping expressions that have been injected into the lambda for parameters. + return expression; + + case NullConditionalExpression nullConditionalExpression: + return VisitNullConditional(nullConditionalExpression); + + default: + return base.VisitExtension(expression); + } + } + + protected virtual Expression VisitNullConditional(NullConditionalExpression expression) + { + var newCaller = Visit(expression.Caller); + var newAccessOperation = Visit(expression.AccessOperation); + var visitedExpression = expression.Update(Unwrap(newCaller), Unwrap(newAccessOperation)); + + // TODO: Can the access operation be anything else than a MemberExpression? + return newCaller is EntityReferenceExpression wrapper + && expression.AccessOperation is MemberExpression memberExpression + ? wrapper.TraverseProperty(memberExpression.Member.Name, visitedExpression) + : visitedExpression; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + // TODO: DRY with NavigationExpansionHelpers + protected static Expression CreateKeyAccessExpression( + [NotNull] Expression target, + [NotNull] IReadOnlyList properties) + => properties.Count == 1 + ? target.CreateEFPropertyExpression(properties[0]) + : Expression.New( + AnonymousObject.AnonymousObjectCtor, + Expression.NewArrayInit( + typeof(object), + properties + .Select( + p => + Expression.Convert( + target.CreateEFPropertyExpression(p), + typeof(object))) + .Cast() + .ToArray())); + + + protected static Expression UnwrapLastNavigation(Expression expression) + => (expression as MemberExpression)?.Expression + ?? (expression is MethodCallExpression methodCallExpression + && methodCallExpression.IsEFProperty() + ? methodCallExpression.Arguments[0] + : null); + + protected static LambdaExpression GetLambdaOrNull(Expression expression) + => expression is LambdaExpression lambda + ? lambda + : expression is UnaryExpression unary && expression.NodeType == ExpressionType.Quote + ? (LambdaExpression)unary.Operand + : null; + + protected static Expression Unwrap(Expression expression) + => expression switch { + EntityReferenceExpression wrapper => wrapper.Underlying, + LambdaExpression lambda when lambda.Body is EntityReferenceExpression wrapper => + Expression.Lambda( + lambda.Type, + wrapper.Underlying, + lambda.TailCall, + lambda.Parameters), + _ => expression + }; + + public class EntityReferenceExpression : Expression + { + public override ExpressionType NodeType => ExpressionType.Extension; + + /// + /// The underlying expression being wrapped. + /// + [NotNull] + public Expression Underlying { get; } + + public override Type Type => Underlying.Type; + + [CanBeNull] + public IEntityType EntityType { get; } + + [CanBeNull] + public INavigation LastNavigation => EntityType == null ? null : _lastNavigation; + + [CanBeNull] + private readonly INavigation _lastNavigation; + + [CanBeNull] + public Dictionary AnonymousType { get; } + + public bool SubqueryTraversed { get; } + + public bool IsAnonymousType => AnonymousType != null; + public bool IsEntityType => EntityType != null; + + public EntityReferenceExpression(Expression underlying, Dictionary anonymousType) + { + Underlying = underlying; + AnonymousType = anonymousType; + } + + public EntityReferenceExpression(Expression underlying, IEntityType entityType) + : this(underlying, entityType, null, false) + { + } + + private EntityReferenceExpression(Expression underlying, IEntityType entityType, INavigation lastNavigation, bool subqueryTraversed) + { + Underlying = underlying; + EntityType = entityType; + _lastNavigation = lastNavigation; + SubqueryTraversed = subqueryTraversed; + } + + public EntityReferenceExpression( + Expression underlying, + IEntityType entityType, + INavigation lastNavigation, + Dictionary anonymousType, + bool subqueryTraversed) + { + Underlying = underlying; + EntityType = entityType; + _lastNavigation = lastNavigation; + AnonymousType = anonymousType; + SubqueryTraversed = subqueryTraversed; + } + + /// + /// Attempts to find as a navigation from the current node, + /// and if successful, returns a new wrapping the + /// given expression. Otherwise returns the given expression without wrapping it. + /// + public virtual Expression TraverseProperty(string propertyName, Expression destinationExpression) + { + if (IsEntityType) + { + return EntityType.FindNavigation(propertyName) is INavigation navigation + ? new EntityReferenceExpression( + destinationExpression, + navigation.GetTargetType(), + navigation, + SubqueryTraversed) + : destinationExpression; + } + + if (IsAnonymousType) + { + if (AnonymousType.TryGetValue(propertyName, out var expression) + && expression is EntityReferenceExpression wrapper) + { + return wrapper.IsEntityType + ? new EntityReferenceExpression(destinationExpression, wrapper.EntityType) + : new EntityReferenceExpression(destinationExpression, wrapper.AnonymousType); + } + + return destinationExpression; + } + + throw new NotSupportedException("Unknown type info"); + } + + public EntityReferenceExpression Update(Expression newUnderlying) + => new EntityReferenceExpression(newUnderlying, EntityType, _lastNavigation, AnonymousType, SubqueryTraversed); + + protected override Expression VisitChildren(ExpressionVisitor visitor) + => Update(visitor.Visit(Underlying)); + + public virtual void Print(ExpressionPrinter expressionPrinter) + { + expressionPrinter.Visit(Underlying); + + if (IsEntityType) + { + expressionPrinter.StringBuilder.Append($".EntityType({EntityType})"); + } + else if (IsAnonymousType) + { + expressionPrinter.StringBuilder.Append(".AnonymousObject"); + } + + if (SubqueryTraversed) + { + expressionPrinter.StringBuilder.Append(".SubqueryTraversed"); + } + } + + public override string ToString() => $"{Underlying}[{(IsEntityType ? EntityType.ShortName(): "AnonymousObject")}{(SubqueryTraversed ? ", Subquery" : "")}]"; + } + } +} diff --git a/src/EFCore/Query/Pipeline/QueryOptimizingExpressionVisitor.cs b/src/EFCore/Query/Pipeline/QueryOptimizingExpressionVisitor.cs index 27c47c4540a..859f7f18a0c 100644 --- a/src/EFCore/Query/Pipeline/QueryOptimizingExpressionVisitor.cs +++ b/src/EFCore/Query/Pipeline/QueryOptimizingExpressionVisitor.cs @@ -22,6 +22,7 @@ public Expression Visit(Expression query) query = new AllAnyToContainsRewritingExpressionVisitor().Visit(query); query = new GroupJoinFlatteningExpressionVisitor().Visit(query); query = new NullCheckRemovingExpressionVisitor().Visit(query); + query = new EntityEqualityRewritingExpressionVisitor(_queryCompilationContext).Rewrite(query); query = new NavigationExpander(_queryCompilationContext.Model).ExpandNavigations(query); query = new EnumerableToQueryableReMappingExpressionVisitor().Visit(query); query = new QueryMetadataExtractingExpressionVisitor(_queryCompilationContext).Visit(query); diff --git a/src/EFCore/Query/Pipeline/QueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore/Query/Pipeline/QueryableMethodTranslatingExpressionVisitor.cs index 1b930a2dbfb..d2494338945 100644 --- a/src/EFCore/Query/Pipeline/QueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore/Query/Pipeline/QueryableMethodTranslatingExpressionVisitor.cs @@ -430,7 +430,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp } } - throw new NotImplementedException(); + throw new NotImplementedException("Unhandled method: " + methodCallExpression.Method.Name); } // TODO: Skip ToOrderedQueryable method. See Issue#15591 diff --git a/test/EFCore.Relational.Specification.Tests/Query/FromSqlQueryTestBase.cs b/test/EFCore.Relational.Specification.Tests/Query/FromSqlQueryTestBase.cs index a8c78a48e78..a9d86536f8a 100644 --- a/test/EFCore.Relational.Specification.Tests/Query/FromSqlQueryTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/Query/FromSqlQueryTestBase.cs @@ -955,6 +955,23 @@ public virtual void FromSqlRaw_does_not_parameterize_interpolated_string() } } + [Fact(Skip = "#15855")] + public virtual void Entity_equality_through_fromsql() + { + using (var context = CreateContext()) + { + var actual = context.Set() + .FromSqlRaw(NormalizeDelimetersInRawString("SELECT * FROM [Orders]")) + .Where(o => o.Customer == new Customer + { + CustomerID = "VINET" + }) + .ToArray(); + + Assert.Equal(1, actual.Length); + } + } + protected string NormalizeDelimetersInRawString(string sql) => Fixture.TestStore.NormalizeDelimetersInRawString(sql); diff --git a/test/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs index 46a43d68ae9..1546e65a3f5 100644 --- a/test/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs @@ -6062,7 +6062,7 @@ public virtual Task Negated_bool_ternary_inside_anonymous_type_in_projection(boo elementSorter: e => e.c); } - [ConditionalTheory(Skip = "Issue #15588")] + [ConditionalTheory(Skip = "issue #15848")] [MemberData(nameof(IsAsyncData))] public virtual Task Order_by_entity_qsre(bool isAsync) { @@ -6084,7 +6084,7 @@ public virtual Task Order_by_entity_qsre_with_inheritance(bool isAsync) assertOrder: true); } - [ConditionalTheory(Skip = "Issue #15588")] + [ConditionalTheory(Skip = "issue #15848")] [MemberData(nameof(IsAsyncData))] public virtual Task Order_by_entity_qsre_composite_key(bool isAsync) { @@ -6096,7 +6096,7 @@ public virtual Task Order_by_entity_qsre_composite_key(bool isAsync) assertOrder: true); } - [ConditionalTheory(Skip = "Issue #15588")] + [ConditionalTheory(Skip = "issue #15848")] [MemberData(nameof(IsAsyncData))] public virtual Task Order_by_entity_qsre_with_other_orderbys(bool isAsync) { @@ -6112,7 +6112,7 @@ public virtual Task Order_by_entity_qsre_with_other_orderbys(bool isAsync) assertOrder: true); } - [ConditionalTheory(Skip = "Issue #15588")] + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Join_on_entity_qsre_keys(bool isAsync) { @@ -6128,7 +6128,7 @@ join w2 in ws on w1 equals w2 elementSorter: e => e.Name1 + " " + e.Name2); } - [ConditionalTheory(Skip = "Issue #15588")] + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Join_on_entity_qsre_keys_composite_key(bool isAsync) { @@ -6144,7 +6144,7 @@ join g2 in gs on g1 equals g2 elementSorter: e => e.GearName1 + " " + e.GearName2); } - [ConditionalTheory(Skip = "Issue #15588")] + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Join_on_entity_qsre_keys_inheritance(bool isAsync) { @@ -6160,7 +6160,7 @@ join o in gs.OfType() on g equals o elementSorter: e => e.GearName + " " + e.OfficerName); } - [ConditionalTheory(Skip = "Issue #15588")] + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Join_on_entity_qsre_keys_outer_key_is_navigation(bool isAsync) { @@ -6210,7 +6210,7 @@ join t in ts.Where(tt => tt.Note == "Cole's Tag" || tt.Note == "Dom's Tag") on g elementSorter: e => e.Nickname + " " + e.Note); } - [ConditionalTheory(Skip = "Issue #15588")] + [ConditionalTheory(Skip = "#15946")] [MemberData(nameof(IsAsyncData))] public virtual Task Join_on_entity_qsre_keys_inner_key_is_nested_navigation(bool isAsync) { diff --git a/test/EFCore.Specification.Tests/Query/IncludeTestBase.cs b/test/EFCore.Specification.Tests/Query/IncludeTestBase.cs index e83f75cc08a..4726835d6fc 100644 --- a/test/EFCore.Specification.Tests/Query/IncludeTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/IncludeTestBase.cs @@ -4099,7 +4099,7 @@ public virtual async Task Include_empty_collection_sets_IsLoaded(bool useString, } } - [ConditionalTheory(Skip = "Issue#15588")] + [ConditionalTheory(Skip = "#15949")] [InlineData(false, false)] [InlineData(true, false)] public virtual async Task Include_empty_reference_sets_IsLoaded(bool useString, bool async) diff --git a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.ResultOperators.cs b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.ResultOperators.cs index 6cb1c57c02d..503b28d979d 100644 --- a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.ResultOperators.cs +++ b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.ResultOperators.cs @@ -1736,7 +1736,7 @@ public virtual Task OrderBy_Skip_Last_gives_correct_result(bool isAsync) entryCount: 1); } - [ConditionalFact(Skip = "Issue#15588")] + [ConditionalFact(Skip = "#15939")] public virtual void Contains_over_entityType_should_rewrite_to_identity_equality() { using (var context = CreateContext()) diff --git a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.Where.cs b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.Where.cs index 0fc047db667..752734aa47a 100644 --- a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.Where.cs +++ b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.Where.cs @@ -1517,7 +1517,7 @@ await AssertQuery( entryCount: 1); } - [ConditionalTheory(Skip = "Issue#15588")] + [ConditionalTheory(Skip = "#15855")] [MemberData(nameof(IsAsyncData))] public virtual async Task Where_poco_closure(bool isAsync) { diff --git a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs index 414bda1c396..49b396e165c 100644 --- a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs @@ -364,7 +364,7 @@ from c in cs select c.CustomerID); } - [ConditionalTheory(Skip = "Issue#15588")] + [ConditionalTheory(Skip = "#15855")] [MemberData(nameof(IsAsyncData))] public virtual Task Entity_equality_local(bool isAsync) { @@ -402,7 +402,7 @@ from c2 in cs select c2, o => o, i => i, (o, i) => o).Select(e => e.CustomerID)); } - [ConditionalTheory(Skip = "Issue#15588")] + [ConditionalTheory(Skip = "#15855")] [MemberData(nameof(IsAsyncData))] public virtual Task Entity_equality_local_inline(bool isAsync) { @@ -417,7 +417,22 @@ from c in cs select c.CustomerID); } - [ConditionalTheory(Skip = "Issue#15588")] + [ConditionalTheory(Skip = "#15855")] + [MemberData(nameof(IsAsyncData))] + public virtual Task Entity_equality_local_inline_composite_key(bool isAsync) + => AssertQuery( + isAsync, + odt => + from od in odt + where od == new OrderDetail + { + OrderID = 10248, + ProductID = 11 + } + select od, + entryCount: 1); + + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Entity_equality_null(bool isAsync) { @@ -429,7 +444,17 @@ from c in cs select c.CustomerID); } - [ConditionalTheory(Skip = "Issue#15588")] + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Entity_equality_null_composite_key(bool isAsync) + => AssertQuery( + isAsync, + odt => + from od in odt + where od == null + select od); + + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Entity_equality_not_null(bool isAsync) { @@ -441,6 +466,65 @@ from c in cs select c.CustomerID); } + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Entity_equality_not_null_composite_key(bool isAsync) + => AssertQuery( + isAsync, + odt => + from od in odt + where od != null + select od, + entryCount: 2155); + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Entity_equality_through_nested_anonymous_type_projection(bool isAsync) + => AssertQuery( + isAsync, + o => o + .Select(x => new + { + CustomerInfo = new { + x.Customer + } + }) + .Where(x => x.CustomerInfo.Customer != null), + entryCount: 89); + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Entity_equality_through_subquery(bool isAsync) + => AssertQuery( + isAsync, + cs => + from c in cs + where c.Orders.FirstOrDefault() != null + select c.CustomerID); + + [Fact] + public virtual void Entity_equality_through_subquery_composite_key() + { + Assert.Throws(() => + CreateContext().Orders + .Where(o => o.OrderDetails.FirstOrDefault() == new OrderDetail + { + OrderID = 10248, + ProductID = 11 + }) + .ToList()); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Entity_equality_through_include(bool isAsync) + => AssertQuery( + isAsync, + cs => + from c in cs.Include(c => c.Orders) + where c == null + select c.CustomerID); + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Null_conditional_simple(bool isAsync) @@ -1876,7 +1960,7 @@ from e1 in es.Take(3) select e1); } - [ConditionalTheory(Skip = "Issue#15588")] + [ConditionalTheory(Skip = "#15855")] [MemberData(nameof(IsAsyncData))] public virtual Task Where_query_composition_entity_equality_one_element_FirstOrDefault(bool isAsync) { @@ -1900,7 +1984,7 @@ from e1 in es.Take(3) select e1); } - [ConditionalFact(Skip = "Issue#15588")] + [ConditionalFact] public virtual void Where_query_composition_entity_equality_no_elements_Single() { using (var ctx = CreateContext()) @@ -1913,7 +1997,7 @@ public virtual void Where_query_composition_entity_equality_no_elements_Single() } } - [ConditionalTheory(Skip = "Issue#15588")] + [ConditionalTheory(Skip = "#15855")] [MemberData(nameof(IsAsyncData))] public virtual Task Where_query_composition_entity_equality_no_elements_FirstOrDefault(bool isAsync) { @@ -1925,7 +2009,7 @@ from e1 in es select e1); } - [ConditionalFact(Skip = "Issue#15588")] + [ConditionalFact] public virtual void Where_query_composition_entity_equality_multiple_elements_SingleOrDefault() { using (var ctx = CreateContext()) @@ -1938,7 +2022,7 @@ public virtual void Where_query_composition_entity_equality_multiple_elements_Si } } - [ConditionalTheory(Skip = "Issue#15588")] + [ConditionalTheory(Skip = "#15855")] [MemberData(nameof(IsAsyncData))] public virtual Task Where_query_composition_entity_equality_multiple_elements_FirstOrDefault(bool isAsync) { @@ -3374,7 +3458,7 @@ where EF.Property(e, "Title") == "Sales Representative" [ConditionalTheory] [MemberData(nameof(IsAsyncData))] - public virtual Task Select_Property_when_shaow_unconstrained_generic_method(bool isAsync) + public virtual Task Select_Property_when_shadow_unconstrained_generic_method(bool isAsync) { return AssertQuery( isAsync, @@ -3384,7 +3468,7 @@ public virtual Task Select_Property_when_shaow_unconstrained_generic_method(bool [ConditionalTheory] [MemberData(nameof(IsAsyncData))] - public virtual Task Where_Property_when_shaow_unconstrained_generic_method(bool isAsync) + public virtual Task Where_Property_when_shadow_unconstrained_generic_method(bool isAsync) { return AssertQuery( isAsync, @@ -5040,7 +5124,7 @@ public virtual Task Int16_parameter_can_be_used_for_int_column(bool isAsync) os => os.Where(o => o.OrderID == parameter), entryCount: 1); } - [ConditionalTheory(Skip = "Issue#15588")] + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Subquery_is_null_translated_correctly(bool isAsync) { @@ -5056,7 +5140,7 @@ from c in cs entryCount: 2); } - [ConditionalTheory(Skip = "Issue#15588")] + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Subquery_is_not_null_translated_correctly(bool isAsync) { @@ -5304,7 +5388,7 @@ orderby c1.CustomerID }); } - [ConditionalTheory(Skip = "Issue#15588")] + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Comparing_different_entity_types_using_Equals(bool isAsync) { @@ -5317,7 +5401,7 @@ where c.Equals(o) select c.CustomerID); } - [ConditionalTheory(Skip = "Issue#15588")] + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Comparing_entity_to_null_using_Equals(bool isAsync) { @@ -5370,7 +5454,7 @@ where Equals(o1.Customer, o2.Customer) e => e.Id1 + " " + e.Id2); } - [ConditionalTheory(Skip = "Issue#15588")] + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Comparing_non_matching_entities_using_Equals(bool isAsync) { @@ -5408,7 +5492,7 @@ where c.Orders.Equals(o.OrderDetails) e => e.Id1 + " " + e.Id2); } - [ConditionalTheory(Skip = "Issue#15588")] + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Comparing_collection_navigation_to_null(bool isAsync) { @@ -5417,7 +5501,7 @@ public virtual Task Comparing_collection_navigation_to_null(bool isAsync) cs => cs.Where(c => c.Orders == null).Select(c => c.CustomerID)); } - [ConditionalTheory(Skip = "Issue#15588")] + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Comparing_collection_navigation_to_null_complex(bool isAsync) { @@ -5707,7 +5791,7 @@ where details.Any() }); } - [ConditionalTheory(Skip = "Issue#15588")] + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Let_entity_equality_to_null(bool isAsync) { @@ -5723,7 +5807,7 @@ public virtual Task Let_entity_equality_to_null(bool isAsync) }); } - [ConditionalTheory(Skip = "Issue#15588")] + [ConditionalTheory(Skip = "#15855")] [MemberData(nameof(IsAsyncData))] public virtual Task Let_entity_equality_to_other_entity(bool isAsync) { @@ -5813,7 +5897,7 @@ select g.OrderByDescending(x => x.OrderID), } - [ConditionalTheory(Skip = "Issue#15588")] + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Collection_navigation_equal_to_null_for_subquery(bool isAsync) { @@ -5824,7 +5908,7 @@ public virtual Task Collection_navigation_equal_to_null_for_subquery(bool isAsyn entryCount: 2); } - [ConditionalTheory(Skip = "Issue#15588")] + [ConditionalTheory(Skip = "Needs AsQueryable")] [MemberData(nameof(IsAsyncData))] public virtual Task Dependent_to_principal_navigation_equal_to_null_for_subquery(bool isAsync) { @@ -5835,7 +5919,7 @@ public virtual Task Dependent_to_principal_navigation_equal_to_null_for_subquery entryCount: 2); } - [ConditionalTheory(Skip = "Issue#15588")] + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Collection_navigation_equality_rewrite_for_subquery(bool isAsync) { @@ -5846,5 +5930,15 @@ public virtual Task Collection_navigation_equality_rewrite_for_subquery(bool isA && os.Where(o => o.OrderID < 10300).OrderBy(o => o.OrderID).FirstOrDefault().OrderDetails == os.Where(o => o.OrderID > 10500).OrderBy(o => o.OrderID).FirstOrDefault().OrderDetails)); } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public void Inner_parameter_in_nested_lambdas_gets_preserved(bool isAsync) + { + AssertQuery( + isAsync, + cs => cs.Where(c => c.Orders.Where(o => c == new Customer { CustomerID = o.CustomerID }).Count() > 0), + entryCount: 90); + } } } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/DbFunctionsSqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/DbFunctionsSqlServerTest.cs index 104cae5f584..784db1fa40e 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/DbFunctionsSqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/DbFunctionsSqlServerTest.cs @@ -235,7 +235,7 @@ await Assert.ThrowsAsync( [ConditionalFact] [SqlServerCondition(SqlServerCondition.SupportsFullTextSearch)] - public async Task FreeText_throws_when_using_non_column_for_proeprty_reference() + public async Task FreeText_throws_when_using_non_column_for_propeprty_reference() { using (var context = CreateContext()) { diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs index d748e180b6a..d7ba6933e4a 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs @@ -199,6 +199,13 @@ FROM [Customers] AS [c] WHERE [c].[CustomerID] = N'ANATR'"); } + public override async Task Entity_equality_local_inline_composite_key(bool isAsync) + { + await base.Entity_equality_local_inline_composite_key(isAsync); + + // TODO: AssertSql + } + public override async Task Entity_equality_null(bool isAsync) { await base.Entity_equality_null(isAsync); @@ -209,6 +216,16 @@ FROM [Customers] AS [c] WHERE [c].[CustomerID] IS NULL"); } + public override async Task Entity_equality_null_composite_key(bool isAsync) + { + await base.Entity_equality_null_composite_key(isAsync); + + AssertSql( + @"SELECT [o].[ProductID] +FROM [Order Details] AS [o] +WHERE CAST(0 AS bit) = CAST(1 AS bit)"); + } + public override async Task Entity_equality_not_null(bool isAsync) { await base.Entity_equality_not_null(isAsync); @@ -219,6 +236,53 @@ FROM [Customers] AS [c] WHERE [c].[CustomerID] IS NOT NULL"); } + public override async Task Entity_equality_not_null_composite_key(bool isAsync) + { + await base.Entity_equality_not_null_composite_key(isAsync); + + AssertSql( + @"SELECT [o].[ProductID] +FROM [Order Details] AS [o] +WHERE CAST(1 AS bit) = CAST(1 AS bit)"); + } + + public override async Task Entity_equality_through_nested_anonymous_type_projection(bool isAsync) + { + await base.Entity_equality_through_nested_anonymous_type_projection(isAsync); + + AssertSql( + @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Orders] AS [o] +LEFT JOIN [Customers] AS [c] ON [o].[CustomerID] = [c].[CustomerID] +WHERE [c].[CustomerID] IS NOT NULL"); + } + + public override async Task Entity_equality_through_subquery(bool isAsync) + { + await base.Entity_equality_through_subquery(isAsync); + + AssertSql( + @"SELECT [c].[CustomerID] +FROM [Customers] AS [c] +WHERE ( + SELECT TOP(1) [o].[OrderID] + FROM [Orders] AS [o] + WHERE ([c].[CustomerID] = [o].[CustomerID]) AND [o].[CustomerID] IS NOT NULL) IS NOT NULL"); + } + + public override async Task Entity_equality_through_include(bool isAsync) + { + await base.Entity_equality_through_include(isAsync); + + AssertSql( + @"SELECT [c].[CustomerID] +FROM [Customers] AS [c] +WHERE ( + SELECT TOP(1) [o].[OrderID] + FROM [Orders] AS [o] + WHERE ([c].[CustomerID] = [o].[CustomerID]) AND [o].[CustomerID] IS NOT NULL) IS NOT NULL"); + } + public override async Task Queryable_reprojection(bool isAsync) { await base.Queryable_reprojection(isAsync);