From 7a8058db6e5fcd599d796fed0373613f1ff6070e Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Tue, 30 Jul 2019 23:14:59 +0200 Subject: [PATCH] Entity equality support for non-anonymous DTOs MemberAssignmentBinding only for now. Fixes #16789 --- ...osmosProjectionBindingExpressionVisitor.cs | 42 +++--- ...emoryProjectionBindingExpressionVisitor.cs | 42 +++--- ...ionalProjectionBindingExpressionVisitor.cs | 42 +++--- ...ntityEqualityRewritingExpressionVisitor.cs | 125 ++++++++++++++---- .../Internal/ExpressionEqualityComparer.cs | 2 +- .../Query/SimpleQueryCosmosTest.cs | 6 + .../Query/QueryNavigationsTestBase.cs | 5 + .../Query/SimpleQueryTestBase.cs | 17 +++ .../Query/SimpleQuerySqlServerTest.cs | 11 ++ 9 files changed, 211 insertions(+), 81 deletions(-) diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosProjectionBindingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosProjectionBindingExpressionVisitor.cs index 1e5b06aa28a..4dfa4e71d37 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosProjectionBindingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosProjectionBindingExpressionVisitor.cs @@ -331,37 +331,43 @@ protected override Expression VisitNew(NewExpression newExpression) /// protected override Expression VisitMemberInit(MemberInitExpression memberInitExpression) { - var newExpression = (NewExpression)Visit(memberInitExpression.NewExpression); + var newExpression = VisitAndConvert(memberInitExpression.NewExpression, nameof(VisitMemberInit)); if (newExpression == null) { return null; } - var newBindings = new MemberAssignment[memberInitExpression.Bindings.Count]; + var newBindings = new MemberBinding[memberInitExpression.Bindings.Count]; for (var i = 0; i < newBindings.Length; i++) { - var memberAssignment = (MemberAssignment)memberInitExpression.Bindings[i]; - if (_clientEval) + newBindings[i] = VisitMemberBinding(memberInitExpression.Bindings[i]); + if (newBindings[i] == null || newBindings[i].BindingType != MemberBindingType.Assignment) { - newBindings[i] = memberAssignment.Update(Visit(memberAssignment.Expression)); + return null; } - else - { - var projectionMember = _projectionMembers.Peek().Append(memberAssignment.Member); - _projectionMembers.Push(projectionMember); + } - var visitedExpression = Visit(memberAssignment.Expression); - if (visitedExpression == null) - { - return null; - } + return memberInitExpression.Update(newExpression, newBindings); + } - newBindings[i] = memberAssignment.Update(visitedExpression); - _projectionMembers.Pop(); - } + protected override MemberAssignment VisitMemberAssignment(MemberAssignment memberAssignment) + { + if (_clientEval) + { + return memberAssignment.Update(Visit(memberAssignment.Expression)); } - return memberInitExpression.Update(newExpression, newBindings); + var projectionMember = _projectionMembers.Peek().Append(memberAssignment.Member); + _projectionMembers.Push(projectionMember); + + var visitedExpression = Visit(memberAssignment.Expression); + if (visitedExpression == null) + { + return null; + } + + _projectionMembers.Pop(); + return memberAssignment.Update(visitedExpression); } // TODO: Debugging diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryProjectionBindingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryProjectionBindingExpressionVisitor.cs index 1b9acd5edfd..9e497fc8288 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryProjectionBindingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryProjectionBindingExpressionVisitor.cs @@ -209,37 +209,43 @@ protected override Expression VisitNew(NewExpression newExpression) protected override Expression VisitMemberInit(MemberInitExpression memberInitExpression) { - var newExpression = (NewExpression)Visit(memberInitExpression.NewExpression); + var newExpression = VisitAndConvert(memberInitExpression.NewExpression, nameof(VisitMemberInit)); if (newExpression == null) { return null; } - var newBindings = new MemberAssignment[memberInitExpression.Bindings.Count]; + var newBindings = new MemberBinding[memberInitExpression.Bindings.Count]; for (var i = 0; i < newBindings.Length; i++) { - var memberAssignment = (MemberAssignment)memberInitExpression.Bindings[i]; - if (_clientEval) + newBindings[i] = VisitMemberBinding(memberInitExpression.Bindings[i]); + if (newBindings[i] == null || newBindings[i].BindingType != MemberBindingType.Assignment) { - newBindings[i] = memberAssignment.Update(Visit(memberAssignment.Expression)); + return null; } - else - { - var projectionMember = _projectionMembers.Peek().Append(memberAssignment.Member); - _projectionMembers.Push(projectionMember); + } - var visitedExpression = Visit(memberAssignment.Expression); - if (visitedExpression == null) - { - return null; - } + return memberInitExpression.Update(newExpression, newBindings); + } - newBindings[i] = memberAssignment.Update(visitedExpression); - _projectionMembers.Pop(); - } + protected override MemberAssignment VisitMemberAssignment(MemberAssignment memberAssignment) + { + if (_clientEval) + { + return memberAssignment.Update(Visit(memberAssignment.Expression)); } - return memberInitExpression.Update(newExpression, newBindings); + var projectionMember = _projectionMembers.Peek().Append(memberAssignment.Member); + _projectionMembers.Push(projectionMember); + + var visitedExpression = Visit(memberAssignment.Expression); + if (visitedExpression == null) + { + return null; + } + + _projectionMembers.Pop(); + return memberAssignment.Update(visitedExpression); } // TODO: Debugging diff --git a/src/EFCore.Relational/Query/Internal/RelationalProjectionBindingExpressionVisitor.cs b/src/EFCore.Relational/Query/Internal/RelationalProjectionBindingExpressionVisitor.cs index 21d6b006ba2..f2efa85bc64 100644 --- a/src/EFCore.Relational/Query/Internal/RelationalProjectionBindingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/Internal/RelationalProjectionBindingExpressionVisitor.cs @@ -232,37 +232,43 @@ protected override Expression VisitNew(NewExpression newExpression) protected override Expression VisitMemberInit(MemberInitExpression memberInitExpression) { - var newExpression = (NewExpression)Visit(memberInitExpression.NewExpression); + var newExpression = VisitAndConvert(memberInitExpression.NewExpression, nameof(VisitMemberInit)); if (newExpression == null) { return null; } - var newBindings = new MemberAssignment[memberInitExpression.Bindings.Count]; + var newBindings = new MemberBinding[memberInitExpression.Bindings.Count]; for (var i = 0; i < newBindings.Length; i++) { - var memberAssignment = (MemberAssignment)memberInitExpression.Bindings[i]; - if (_clientEval) + newBindings[i] = VisitMemberBinding(memberInitExpression.Bindings[i]); + if (newBindings[i] == null || newBindings[i].BindingType != MemberBindingType.Assignment) { - newBindings[i] = memberAssignment.Update(Visit(memberAssignment.Expression)); + return null; } - else - { - var projectionMember = _projectionMembers.Peek().Append(memberAssignment.Member); - _projectionMembers.Push(projectionMember); + } - var visitedExpression = Visit(memberAssignment.Expression); - if (visitedExpression == null) - { - return null; - } + return memberInitExpression.Update(newExpression, newBindings); + } - newBindings[i] = memberAssignment.Update(visitedExpression); - _projectionMembers.Pop(); - } + protected override MemberAssignment VisitMemberAssignment(MemberAssignment memberAssignment) + { + if (_clientEval) + { + return memberAssignment.Update(Visit(memberAssignment.Expression)); } - return memberInitExpression.Update(newExpression, newBindings); + var projectionMember = _projectionMembers.Peek().Append(memberAssignment.Member); + _projectionMembers.Push(projectionMember); + + var visitedExpression = Visit(memberAssignment.Expression); + if (visitedExpression == null) + { + return null; + } + + _projectionMembers.Pop(); + return memberAssignment.Update(visitedExpression); } // TODO: Debugging diff --git a/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs b/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs index 239a839237d..51cc930b8df 100644 --- a/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Collections.ObjectModel; using System.Linq; using System.Linq.Expressions; using System.Reflection; @@ -56,13 +57,67 @@ protected override Expression VisitNew(NewExpression newExpression) var visitedArgs = Visit(newExpression.Arguments); var visitedExpression = newExpression.Update(visitedArgs.Select(Unwrap)); - return (newExpression.Members?.Count ?? 0) == 0 + // NewExpression.Members is populated for anonymous types, mapping constructor arguments to the properties + // which receive their values. If not populated, a non-anonymous type is being constructed, and we have no idea where + // its constructor arguments will end up. + if (newExpression.Members == null) + { + return visitedExpression; + } + + var entityReferenceInfo = visitedArgs + .Select((a, i) => (Arg: a, Index: i)) + .Where(ai => ai.Arg is EntityReferenceExpression) + .ToDictionary( + ai => visitedExpression.Members[ai.Index].Name, + ai => EntityOrDtoType.FromEntityReferenceExpression((EntityReferenceExpression)ai.Arg)); + + return entityReferenceInfo.Count == 0 ? (Expression)visitedExpression - : new EntityReferenceExpression(visitedExpression, visitedExpression.Members - .Select((m, i) => (Member: m, Index: i)) - .ToDictionary( - mi => mi.Member.Name, - mi => visitedArgs[mi.Index])); + : new EntityReferenceExpression(visitedExpression, entityReferenceInfo); + } + + protected override Expression VisitMemberInit(MemberInitExpression memberInitExpression) + { + var visitedNew = Visit(memberInitExpression.NewExpression); + var (visitedBindings, entityReferenceInfo) = VisitMemberBindings(memberInitExpression.Bindings); + var visitedMemberInit = memberInitExpression.Update((NewExpression)Unwrap(visitedNew), visitedBindings); + + return entityReferenceInfo == null + ? (Expression)visitedMemberInit + : new EntityReferenceExpression(visitedMemberInit, entityReferenceInfo); + + // Visits member bindings, unwrapping expressions and surfacing entity reference information via the dictionary + (IEnumerable, Dictionary) VisitMemberBindings(ReadOnlyCollection bindings) + { + var newBindings = new MemberBinding[bindings.Count]; + Dictionary bindingEntityReferenceInfo = null; + + for (var i = 0; i < bindings.Count; i++) + { + switch (bindings[i]) + { + case MemberAssignment assignment: + var visitedAssignment = VisitMemberAssignment(assignment); + if (visitedAssignment.Expression is EntityReferenceExpression ere) + { + if (bindingEntityReferenceInfo == null) + { + bindingEntityReferenceInfo = new Dictionary(); + } + bindingEntityReferenceInfo[assignment.Member.Name] = EntityOrDtoType.FromEntityReferenceExpression(ere); + } + newBindings[i] = assignment.Update(Unwrap(visitedAssignment.Expression)); + continue; + + default: + newBindings[i] = VisitMemberBinding(bindings[i]); + continue; + } + } + + return (newBindings, bindingEntityReferenceInfo); + } } protected override Expression VisitMember(MemberExpression memberExpression) @@ -238,7 +293,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp methodCallExpression.Update(null, newArguments), newSourceWrapper.EntityType, lastNavigation: null, - newSourceWrapper.AnonymousType, + newSourceWrapper.DtoType, subqueryTraversed: true); } } @@ -557,10 +612,10 @@ protected virtual Expression RewriteEquality(bool equality, Expression left, Exp var leftTypeWrapper = left as EntityReferenceExpression; var rightTypeWrapper = right as EntityReferenceExpression; - // If one of the sides is an anonymous object, or both sides are unknown, abort + // If one of the sides is a DTO, or both sides are unknown, abort if (leftTypeWrapper == null && rightTypeWrapper == null - || leftTypeWrapper?.IsAnonymousType == true - || rightTypeWrapper?.IsAnonymousType == true) + || leftTypeWrapper?.IsDtoType == true + || rightTypeWrapper?.IsDtoType == true) { return null; } @@ -786,6 +841,25 @@ protected static Expression Unwrap(Expression expression) _ => expression }; + protected struct EntityOrDtoType + { + public static EntityOrDtoType FromEntityReferenceExpression(EntityReferenceExpression ere) + => new EntityOrDtoType + { + EntityType = ere.IsEntityType ? ere.EntityType : null, + DtoType = ere.IsDtoType ? ere.DtoType : null + }; + + public static EntityOrDtoType FromDtoType(Dictionary dtoType) + => new EntityOrDtoType { DtoType = dtoType }; + + public bool IsEntityType => EntityType != null; + public bool IsDto => DtoType != null; + + public IEntityType EntityType; + public Dictionary DtoType; + } + protected class EntityReferenceExpression : Expression { public sealed override ExpressionType NodeType => ExpressionType.Extension; @@ -808,17 +882,17 @@ protected class EntityReferenceExpression : Expression private readonly INavigation _lastNavigation; [CanBeNull] - public Dictionary AnonymousType { get; } + public Dictionary DtoType { get; } public bool SubqueryTraversed { get; } - public bool IsAnonymousType => AnonymousType != null; + public bool IsDtoType => DtoType != null; public bool IsEntityType => EntityType != null; - public EntityReferenceExpression(Expression underlying, Dictionary anonymousType) + public EntityReferenceExpression(Expression underlying, Dictionary dtoType) { Underlying = underlying; - AnonymousType = anonymousType; + DtoType = dtoType; } public EntityReferenceExpression(Expression underlying, IEntityType entityType) @@ -838,13 +912,13 @@ public EntityReferenceExpression( Expression underlying, IEntityType entityType, INavigation lastNavigation, - Dictionary anonymousType, + Dictionary dtoType, bool subqueryTraversed) { Underlying = underlying; EntityType = entityType; _lastNavigation = lastNavigation; - AnonymousType = anonymousType; + DtoType = dtoType; SubqueryTraversed = subqueryTraversed; } @@ -866,14 +940,13 @@ public virtual Expression TraverseProperty(string propertyName, Expression desti : destinationExpression; } - if (IsAnonymousType) + if (IsDtoType) { - if (AnonymousType.TryGetValue(propertyName, out var expression) - && expression is EntityReferenceExpression wrapper) + if (DtoType.TryGetValue(propertyName, out var entityOrDto)) { - return wrapper.IsEntityType - ? new EntityReferenceExpression(destinationExpression, wrapper.EntityType) - : new EntityReferenceExpression(destinationExpression, wrapper.AnonymousType); + return entityOrDto.IsEntityType + ? new EntityReferenceExpression(destinationExpression, entityOrDto.EntityType) + : new EntityReferenceExpression(destinationExpression, entityOrDto.DtoType); } return destinationExpression; @@ -883,7 +956,7 @@ public virtual Expression TraverseProperty(string propertyName, Expression desti } public EntityReferenceExpression Update(Expression newUnderlying) - => new EntityReferenceExpression(newUnderlying, EntityType, _lastNavigation, AnonymousType, SubqueryTraversed); + => new EntityReferenceExpression(newUnderlying, EntityType, _lastNavigation, DtoType, SubqueryTraversed); protected override Expression VisitChildren(ExpressionVisitor visitor) => Update(visitor.Visit(Underlying)); @@ -896,9 +969,9 @@ public virtual void Print(ExpressionPrinter expressionPrinter) { expressionPrinter.StringBuilder.Append($".EntityType({EntityType})"); } - else if (IsAnonymousType) + else if (IsDtoType) { - expressionPrinter.StringBuilder.Append(".AnonymousObject"); + expressionPrinter.StringBuilder.Append(".DTO"); } if (SubqueryTraversed) @@ -907,7 +980,7 @@ public virtual void Print(ExpressionPrinter expressionPrinter) } } - public override string ToString() => $"{Underlying}[{(IsEntityType ? EntityType.ShortName() : "AnonymousObject")}{(SubqueryTraversed ? ", Subquery" : "")}]"; + public override string ToString() => $"{Underlying}[{(IsEntityType ? EntityType.ShortName() : "DTO")}{(SubqueryTraversed ? ", Subquery" : "")}]"; } } } diff --git a/src/EFCore/Query/Internal/ExpressionEqualityComparer.cs b/src/EFCore/Query/Internal/ExpressionEqualityComparer.cs index 81eac3c2b4a..b82c9171a65 100644 --- a/src/EFCore/Query/Internal/ExpressionEqualityComparer.cs +++ b/src/EFCore/Query/Internal/ExpressionEqualityComparer.cs @@ -711,7 +711,7 @@ private bool CompareBinding(MemberBinding a, MemberBinding b) case MemberBindingType.MemberBinding: return CompareMemberMemberBinding((MemberMemberBinding)a, (MemberMemberBinding)b); default: - throw new NotImplementedException(); + throw new InvalidOperationException("Unhandled member binding type: " + a.BindingType); } } diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.cs index 2ccde2947d0..860abd8cdcf 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.cs @@ -3827,6 +3827,12 @@ public override Task Entity_equality_through_nested_anonymous_type_projection(bo return base.Entity_equality_through_nested_anonymous_type_projection(isAsync); } + [ConditionalTheory(Skip = "Issue#14935")] + public override async Task Entity_equality_through_DTO_projection(bool isAsync) + { + await base.Entity_equality_through_DTO_projection(isAsync); + } + [ConditionalTheory(Skip = "Issue#14935")] public override Task GroupJoin_customers_employees_shadow(bool isAsync) { diff --git a/test/EFCore.Specification.Tests/Query/QueryNavigationsTestBase.cs b/test/EFCore.Specification.Tests/Query/QueryNavigationsTestBase.cs index 39bf8d97564..8e782d6b838 100644 --- a/test/EFCore.Specification.Tests/Query/QueryNavigationsTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/QueryNavigationsTestBase.cs @@ -1537,5 +1537,10 @@ public virtual Task Multiple_include_with_multiple_optional_navigations(bool isA .Where(od => od.Order.Customer.City == "London"), entryCount: 221); } + + private class OrderDTO + { + public Customer Customer { get; set; } + } } } diff --git a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs index 66ea0e0120a..92cb86e93ae 100644 --- a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs @@ -524,6 +524,23 @@ public virtual Task Entity_equality_through_nested_anonymous_type_projection(boo .Where(x => x.CustomerInfo.Customer != null), entryCount: 89); + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Entity_equality_through_DTO_projection(bool isAsync) + => AssertQuery( + isAsync, + o => o + .Select(o => new CustomerWrapper { Customer = o.Customer }) + .Where(x => x.Customer != null), + entryCount: 89); + + private class CustomerWrapper + { + public Customer Customer { get; set; } + public override bool Equals(object obj) => obj is CustomerWrapper other && other.Customer.Equals(Customer); + public override int GetHashCode() => Customer.GetHashCode(); + } + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Entity_equality_through_subquery(bool isAsync) diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs index 6eae3ece68a..5b094f3c65e 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs @@ -297,6 +297,17 @@ FROM [Orders] AS [o] WHERE [c].[CustomerID] IS NOT NULL"); } + public override async Task Entity_equality_through_DTO_projection(bool isAsync) + { + await base.Entity_equality_through_DTO_projection(isAsync); + + AssertSql( + @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Orders] AS [o] +LEFT JOIN [Customers] AS [c] ON [o].[CustomerID] = [c].[CustomerID] +WHERE [c].[CustomerID] IS NOT NULL"); + } + public override async Task Entity_equality_through_subquery(bool isAsync) { await base.Entity_equality_through_subquery(isAsync);