diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs index 5b10b4e372c..4bd104b36eb 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs @@ -1358,14 +1358,6 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp ?? methodCallExpression.Update(null!, new[] { source, methodCallExpression.Arguments[1] }); } - if (methodCallExpression.TryGetEFPropertyArguments(out source, out navigationName)) - { - source = Visit(source); - - return TryExpand(source, MemberIdentity.Create(navigationName)) - ?? methodCallExpression.Update(source, new[] { methodCallExpression.Arguments[0] }); - } - return base.VisitMethodCall(methodCallExpression); } diff --git a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs index 9adff93a033..c72f3c30cc4 100644 --- a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Linq.Expressions; using System.Reflection; @@ -1146,6 +1147,7 @@ private static readonly MethodInfo _objectEqualsMethodInfo private readonly ISqlExpressionFactory _sqlExpressionFactory; private SelectExpression _selectExpression; + private DeferredOwnedExpansionRemovingVisitor _visitor; public SharedTypeEntityExpandingExpressionVisitor( RelationalSqlTranslatingExpressionVisitor sqlTranslator, @@ -1154,13 +1156,15 @@ public SharedTypeEntityExpandingExpressionVisitor( _sqlTranslator = sqlTranslator; _sqlExpressionFactory = sqlExpressionFactory; _selectExpression = null!; + _visitor = null!; } public Expression Expand(SelectExpression selectExpression, Expression lambdaBody) { _selectExpression = selectExpression; + _visitor = new(_selectExpression); - return Visit(lambdaBody); + return _visitor.Visit(Visit(lambdaBody)); } protected override Expression VisitMember(MemberExpression memberExpression) @@ -1185,14 +1189,6 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp ?? methodCallExpression.Update(null!, new[] { source, methodCallExpression.Arguments[1] }); } - if (methodCallExpression.TryGetEFPropertyArguments(out source, out navigationName)) - { - source = Visit(source); - - return TryExpand(source, MemberIdentity.Create(navigationName)) - ?? methodCallExpression.Update(source, new[] { methodCallExpression.Arguments[1] }); - } - return base.VisitMethodCall(methodCallExpression); } @@ -1209,6 +1205,11 @@ protected override Expression VisitExtension(Expression extensionExpression) private Expression? TryExpand(Expression? source, MemberIdentity member) { source = source.UnwrapTypeConversion(out var convertedType); + var doee = source as DeferredOwnedExpansionExpression; + if (doee is not null) + { + source = _visitor.UnwrapDeferredEntityProjectionExpression(doee); + } if (source is not EntityShaperExpression entityShaperExpression) { return null; @@ -1294,11 +1295,7 @@ outerKey is NewArrayExpression newArrayExpression Expression.Quote(correlationPredicate)); } - var entityProjectionExpression = (EntityProjectionExpression) - (entityShaperExpression.ValueBufferExpression is ProjectionBindingExpression projectionBindingExpression - ? _selectExpression.GetProjection(projectionBindingExpression) - : entityShaperExpression.ValueBufferExpression); - + var entityProjectionExpression = GetEntityProjectionExpression(entityShaperExpression); var innerShaper = entityProjectionExpression.BindNavigation(navigation); if (innerShaper == null) { @@ -1361,7 +1358,22 @@ outerKey is NewArrayExpression newArrayExpression makeNullable); var joinPredicate = _sqlTranslator.Translate(Expression.Equal(outerKey, innerKey))!; + // Following conditions should match conditions for pushdown on outer during SelectExpression.AddJoin method + var pushdownRequired = _selectExpression.Limit != null + || _selectExpression.Offset != null + || _selectExpression.IsDistinct + || _selectExpression.GroupBy.Count > 0; _selectExpression.AddLeftJoin(innerSelectExpression, joinPredicate); + + // If pushdown was required on SelectExpression then we need to fetch the updated entity projection + if (pushdownRequired) + { + if (doee is not null) + { + entityShaperExpression = _visitor.UnwrapDeferredEntityProjectionExpression(doee); + } + entityProjectionExpression = GetEntityProjectionExpression(entityShaperExpression); + } var leftJoinTable = _selectExpression.Tables.Last(); innerShaper = new RelationalEntityShaperExpression( @@ -1374,13 +1386,115 @@ outerKey is NewArrayExpression newArrayExpression entityProjectionExpression.AddNavigationBinding(navigation, innerShaper); } - return innerShaper; + return doee is not null + ? doee.AddNavigation(targetEntityType, navigation) + : new DeferredOwnedExpansionExpression(targetEntityType, + (ProjectionBindingExpression)entityShaperExpression.ValueBufferExpression, + navigation); } private static Expression AddConvertToObject(Expression expression) => expression.Type.IsValueType ? Expression.Convert(expression, typeof(object)) : expression; + + private EntityProjectionExpression GetEntityProjectionExpression(EntityShaperExpression entityShaperExpression) + { + return entityShaperExpression.ValueBufferExpression switch + { + ProjectionBindingExpression projectionBindingExpression + => (EntityProjectionExpression)_selectExpression.GetProjection(projectionBindingExpression), + EntityProjectionExpression entityProjectionExpression => entityProjectionExpression, + _ => throw new InvalidOperationException(), + }; + } + + private sealed class DeferredOwnedExpansionExpression : Expression + { + private readonly IEntityType _entityType; + + public DeferredOwnedExpansionExpression( + IEntityType entityType, + ProjectionBindingExpression projectionBindingExpression, + INavigation navigation) + { + _entityType = entityType; + ProjectionBindingExpression = projectionBindingExpression; + NavigationChain = new List { navigation }; + } + + private DeferredOwnedExpansionExpression( + IEntityType entityType, + ProjectionBindingExpression projectionBindingExpression, + List navigationChain) + { + _entityType = entityType; + ProjectionBindingExpression = projectionBindingExpression; + NavigationChain = navigationChain; + } + + public ProjectionBindingExpression ProjectionBindingExpression { get; } + public List NavigationChain { get; } + + public DeferredOwnedExpansionExpression AddNavigation(IEntityType entityType, INavigation navigation) + { + var navigationChain = new List(NavigationChain.Count + 1); + navigationChain.AddRange(NavigationChain); + navigationChain.Add(navigation); + + return new DeferredOwnedExpansionExpression( + entityType, + ProjectionBindingExpression, + navigationChain); + } + + public override Type Type => _entityType.ClrType; + + public override ExpressionType NodeType => ExpressionType.Extension; + } + + private sealed class DeferredOwnedExpansionRemovingVisitor : ExpressionVisitor + { + private readonly SelectExpression _selectExpression; + + public DeferredOwnedExpansionRemovingVisitor(SelectExpression selectExpression) + { + _selectExpression = selectExpression; + } + + [return: NotNullIfNotNull("expression")] + public override Expression? Visit(Expression? expression) + { + switch (expression) + { + case DeferredOwnedExpansionExpression doee: + return UnwrapDeferredEntityProjectionExpression(doee); + + // For the source entity shaper skip + case EntityShaperExpression _: + // For owned collection expansion + case ShapedQueryExpression _: + return expression; + + default: + return base.Visit(expression); + } + } + + public EntityShaperExpression UnwrapDeferredEntityProjectionExpression(DeferredOwnedExpansionExpression doee) + { + var entityProjection = (EntityProjectionExpression)_selectExpression.GetProjection(doee.ProjectionBindingExpression); + var entityShaper = entityProjection.BindNavigation(doee.NavigationChain[0])!; + + for (var i = 1; i < doee.NavigationChain.Count; i++) + { + entityProjection = (EntityProjectionExpression)entityShaper.ValueBufferExpression; + entityShaper = entityProjection.BindNavigation(doee.NavigationChain[i])!; + } + + return entityShaper; + } + } } private ShapedQueryExpression TranslateTwoParameterSelector(ShapedQueryExpression source, LambdaExpression resultSelector) diff --git a/test/EFCore.InMemory.FunctionalTests/Query/OwnedEntityQueryInMemoryTest.cs b/test/EFCore.InMemory.FunctionalTests/Query/OwnedEntityQueryInMemoryTest.cs index 8e220067c2f..cc9c4f31acf 100644 --- a/test/EFCore.InMemory.FunctionalTests/Query/OwnedEntityQueryInMemoryTest.cs +++ b/test/EFCore.InMemory.FunctionalTests/Query/OwnedEntityQueryInMemoryTest.cs @@ -72,5 +72,23 @@ protected class Foo public virtual Bar? Bar { get; set; } } #nullable disable + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Owned_references_on_same_level_expanded_at_different_times_around_take(bool async) + { + var contextFactory = await InitializeAsync(seed: c => c.Seed()); + using var context = contextFactory.CreateContext(); + + await base.Owned_references_on_same_level_expanded_at_different_times_around_take_helper(context, async); + } + + protected class MyContext26592 : MyContext26592Base + { + public MyContext26592(DbContextOptions options) + : base(options) + { + } + } } } diff --git a/test/EFCore.Relational.Specification.Tests/Query/OwnedEntityQueryRelationalTestBase.cs b/test/EFCore.Relational.Specification.Tests/Query/OwnedEntityQueryRelationalTestBase.cs index a4f604e6b12..d62738c6ea2 100644 --- a/test/EFCore.Relational.Specification.Tests/Query/OwnedEntityQueryRelationalTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/Query/OwnedEntityQueryRelationalTestBase.cs @@ -164,5 +164,33 @@ protected class PublishTokenType25680 public string TokenGroupId { get; set; } public string IssuerName { get; set; } } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Owned_reference_mapped_to_different_table_updated_correctly_after_subquery_pushdown(bool async) + { + var contextFactory = await InitializeAsync(seed: c => c.Seed()); + using var context = contextFactory.CreateContext(); + + await base.Owned_references_on_same_level_expanded_at_different_times_around_take_helper(context, async); + } + + protected class MyContext26592 : MyContext26592Base + { + public MyContext26592(DbContextOptions options) + : base(options) + { + } + + protected override void OnModelCreating(ModelBuilder modelBuilder) + { + modelBuilder.Entity( + b => + { + b.OwnsOne(e => e.CustomerData).ToTable("CustomerData"); + b.OwnsOne(e => e.SupplierData).ToTable("SupplierData"); + }); + } + } } } diff --git a/test/EFCore.Specification.Tests/Query/OwnedEntityQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/OwnedEntityQueryTestBase.cs index f549de2e367..b3aaedc7168 100644 --- a/test/EFCore.Specification.Tests/Query/OwnedEntityQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/OwnedEntityQueryTestBase.cs @@ -308,5 +308,75 @@ protected class WarehouseModel public ICollection DestinationCountryCodes { get; set; } } + protected virtual async Task Owned_references_on_same_level_expanded_at_different_times_around_take_helper(MyContext26592Base context, bool async) + { + var query = context.Companies.Where(e => e.CustomerData != null).OrderBy(e => e.Id).Take(10); + var result = async + ? await query.ToListAsync() + : query.ToList(); + + var company = Assert.Single(result); + Assert.Equal("Acme Inc.", company.Name); + Assert.Equal("Regular", company.CustomerData.AdditionalCustomerData); + Assert.Equal("Free shipping", company.SupplierData.AdditionalSupplierData); + } + + protected abstract class MyContext26592Base : DbContext + { + protected MyContext26592Base(DbContextOptions options) + : base(options) + { + } + + public DbSet Companies { get; set; } + + protected override void OnModelCreating(ModelBuilder modelBuilder) + { + modelBuilder.Entity( + b => + { + b.OwnsOne(e => e.CustomerData); + b.OwnsOne(e => e.SupplierData); + }); + } + + public void Seed() + { + Add(new Company + { + Name = "Acme Inc.", + CustomerData = new CustomerData + { + AdditionalCustomerData = "Regular" + }, + SupplierData = new SupplierData + { + AdditionalSupplierData = "Free shipping" + } + }); + + SaveChanges(); + } + } + + protected class Company + { + public int Id { get; set; } + public string Name { get; set; } + public CustomerData CustomerData { get; set; } + public SupplierData SupplierData { get; set; } + } + + protected class CustomerData + { + public int Id { get; set; } + public string AdditionalCustomerData { get; set; } + } + + protected class SupplierData + { + public int Id { get; set; } + public string AdditionalSupplierData { get; set; } + } } } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/OwnedEntityQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/OwnedEntityQuerySqlServerTest.cs index e30e0f64966..480d948a7db 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/OwnedEntityQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/OwnedEntityQuerySqlServerTest.cs @@ -99,5 +99,24 @@ FROM [Location25680] AS [l] WHERE [l].[Id] = @__id_0 ORDER BY [l].[Id]"); } + + public override async Task Owned_reference_mapped_to_different_table_updated_correctly_after_subquery_pushdown(bool async) + { + await base.Owned_reference_mapped_to_different_table_updated_correctly_after_subquery_pushdown(async); + + AssertSql( + @"@__p_0='10' + +SELECT [t].[Id], [t].[Name], [t].[CompanyId], [t].[AdditionalCustomerData], [t].[Id0], [s].[CompanyId], [s].[AdditionalSupplierData], [s].[Id] +FROM ( + SELECT TOP(@__p_0) [c].[Id], [c].[Name], [c0].[CompanyId], [c0].[AdditionalCustomerData], [c0].[Id] AS [Id0] + FROM [Companies] AS [c] + LEFT JOIN [CustomerData] AS [c0] ON [c].[Id] = [c0].[CompanyId] + WHERE [c0].[CompanyId] IS NOT NULL + ORDER BY [c].[Id] +) AS [t] +LEFT JOIN [SupplierData] AS [s] ON [t].[Id] = [s].[CompanyId] +ORDER BY [t].[Id]"); + } } }