Skip to content

Commit

Permalink
Fix to #5672 - Query :: Include with multiple navigations (including …
Browse files Browse the repository at this point in the history
…optional navigation) fails

Problem was that our include logic for groupjoin was expecting outer and inner elements (on which the include was being performed) to be entities. However sometimes those elements would be TransparentIdentifiers (e.g. when the element comes from a result of SelectMany).

Fix is to also store (when necessary) the accessor from the outer/inner element to the entity that we want to include, extract the actual entity in runtime and apply include operations on that entity instead of the actual outer/inner element.
  • Loading branch information
maumar committed Jul 21, 2016
1 parent 9a34e1d commit c7ff76d
Show file tree
Hide file tree
Showing 13 changed files with 287 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,15 @@ var outer

if (_groupJoinAsyncEnumerable._outerGroupJoinInclude != null)
{
await _groupJoinAsyncEnumerable._outerGroupJoinInclude.IncludeAsync(outer, cancellationToken);
var outerEntityAccessor = _groupJoinAsyncEnumerable._outerGroupJoinInclude.EntityAccessor as Func<TOuter, object>;
if (outerEntityAccessor != null)
{
await _groupJoinAsyncEnumerable._outerGroupJoinInclude.IncludeAsync(outerEntityAccessor(outer), cancellationToken);
}
else
{
await _groupJoinAsyncEnumerable._outerGroupJoinInclude.IncludeAsync(outer, cancellationToken);
}
}

var inner
Expand All @@ -442,7 +450,15 @@ var inner

if (_groupJoinAsyncEnumerable._innerGroupJoinInclude != null)
{
await _groupJoinAsyncEnumerable._innerGroupJoinInclude.IncludeAsync(inner, cancellationToken);
var innerEntityAccessor = _groupJoinAsyncEnumerable._innerGroupJoinInclude.EntityAccessor as Func<TInner, object>;
if (innerEntityAccessor != null)
{
await _groupJoinAsyncEnumerable._innerGroupJoinInclude.IncludeAsync(innerEntityAccessor(inner), cancellationToken);
}
else
{
await _groupJoinAsyncEnumerable._innerGroupJoinInclude.IncludeAsync(inner, cancellationToken);
}
}

inners.Add(inner);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,22 @@ var shaper
{
var groupJoinIncludeArgumentIndex = shaperArgumentIndex + 4;

var existingGroupJoinIncludeArgument = methodCallExpression.Arguments[groupJoinIncludeArgumentIndex];
var existingGroupJoinIncludeWithAccessor = existingGroupJoinIncludeArgument as MethodCallExpression;
var withAccessorMethodInfo = _queryCompilationContext.QueryMethodProvider.GroupJoinIncludeType
.GetTypeInfo().GetDeclaredMethod(nameof(GroupJoinInclude.WithEntityAccessor));

var existingGroupJoinIncludeExpression = existingGroupJoinIncludeWithAccessor != null
&& existingGroupJoinIncludeWithAccessor.Method == withAccessorMethodInfo
? existingGroupJoinIncludeWithAccessor.Object
: existingGroupJoinIncludeArgument;

var groupJoinInclude
= _queryCompilationContext.QueryMethodProvider
.CreateGroupJoinInclude(
_navigationPath,
_querySourceRequiresTracking,
(methodCallExpression.Arguments[groupJoinIncludeArgumentIndex] as ConstantExpression)?.Value,
(existingGroupJoinIncludeExpression as ConstantExpression)?.Value,
_createRelatedEntitiesLoadersMethodInfo
.MakeGenericMethod(_queryCompilationContext.QueryMethodProvider.RelatedEntitiesLoaderType)
.Invoke(this, new object[]
Expand All @@ -161,9 +171,18 @@ var groupJoinInclude

if (groupJoinInclude != null)
{
var newArguments = methodCallExpression.Arguments.ToList();
var groupJoinIncludeExpression = (Expression)Expression.Constant(groupJoinInclude);
var accessorLambda = shaper.GetAccessorExpression(_querySource) as LambdaExpression;
if (accessorLambda != null && accessorLambda.Parameters.Single().Type.GetTypeInfo().IsValueType)
{
groupJoinIncludeExpression = Expression.Call(
groupJoinIncludeExpression,
withAccessorMethodInfo,
shaper.GetAccessorExpression(_querySource));
}

newArguments[groupJoinIncludeArgumentIndex] = Expression.Constant(groupJoinInclude);
var newArguments = methodCallExpression.Arguments.ToList();
newArguments[groupJoinIncludeArgumentIndex] = groupJoinIncludeExpression;

return methodCallExpression.Update(methodCallExpression.Object, newArguments);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ public class AsyncGroupJoinInclude : IDisposable
private RelationalQueryContext _queryContext;
private IAsyncRelatedEntitiesLoader[] _relatedEntitiesLoaders;
private AsyncGroupJoinInclude _previous;
private Delegate _entityAccessor;

/// <summary>
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
Expand All @@ -39,6 +40,23 @@ public AsyncGroupJoinInclude(
_querySourceRequiresTracking = querySourceRequiresTracking;
}

/// <summary>
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
/// </summary>
public Delegate EntityAccessor => _entityAccessor;

/// <summary>
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
/// </summary>
public virtual AsyncGroupJoinInclude WithEntityAccessor([NotNull] Delegate entityAccessor)
{
_entityAccessor = entityAccessor;

return this;
}

/// <summary>
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public class GroupJoinInclude : IDisposable
private RelationalQueryContext _queryContext;
private IRelatedEntitiesLoader[] _relatedEntitiesLoaders;
private GroupJoinInclude _previous;
private Delegate _entityAccessor;

/// <summary>
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
Expand All @@ -37,6 +38,23 @@ public GroupJoinInclude(
_querySourceRequiresTracking = querySourceRequiresTracking;
}

/// <summary>
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
/// </summary>
public Delegate EntityAccessor => _entityAccessor;

/// <summary>
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
/// </summary>
public virtual GroupJoinInclude WithEntityAccessor([NotNull] Delegate entityAccessor)
{
_entityAccessor = entityAccessor;

return this;
}

/// <summary>
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,15 @@ var outer

nextOuter = default(TOuter);

outerGroupJoinInclude?.Include(outer);
var outerAccessor = outerGroupJoinInclude?.EntityAccessor as Func<TOuter, object>;
if (outerAccessor != null)
{
outerGroupJoinInclude?.Include(outerAccessor(outer));
}
else
{
outerGroupJoinInclude?.Include(outer);
}

var inner = innerShaper.Shape(queryContext, sourceEnumerator.Current);
var inners = new List<TInner>();
Expand All @@ -307,7 +315,15 @@ var outer
{
var currentGroupKey = innerKeySelector(inner);

innerGroupJoinInclude?.Include(inner);
var innerAccessor = innerGroupJoinInclude?.EntityAccessor as Func<TInner, object>;
if (innerAccessor != null)
{
innerGroupJoinInclude?.Include(innerAccessor(inner));
}
else
{
innerGroupJoinInclude?.Include(outer);
}

inners.Add(inner);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.Specification.Tests.TestModels.Northwind;
using Microsoft.EntityFrameworkCore.Specification.Tests.TestUtilities.Xunit;
using Xunit;

namespace Microsoft.EntityFrameworkCore.Specification.Tests
{
public abstract class AsyncQueryNavigationsTestBase<TFixture> : IClassFixture<TFixture>
where TFixture : NorthwindQueryFixtureBase, new()
{
[ConditionalFact]
public virtual async Task Include_with_multiple_optional_navigations()
{
await AssertQuery<OrderDetail>(
ods => ods
.Include(od => od.Order.Customer)
.Where(od => od.Order.Customer.City == "London"),
entryCount: 164);
}

[ConditionalFact]
public virtual async Task Multiple_include_with_multiple_optional_navigations()
{
await AssertQuery<OrderDetail>(
ods => ods
.Include(od => od.Order.Customer)
.Include(od => od.Product)
.Where(od => od.Order.Customer.City == "London"),
entryCount: 221);
}

protected NorthwindContext CreateContext()
{
return Fixture.CreateContext();
}

protected AsyncQueryNavigationsTestBase(TFixture fixture)
{
Fixture = fixture;
}

protected TFixture Fixture { get; }

protected async Task AssertQuery<TItem>(
Func<IQueryable<TItem>, IQueryable<object>> query,
bool assertOrder = false,
int entryCount = 0,
Action<IList<object>, IList<object>> asserter = null)
where TItem : class
=> await AssertQuery(query, query, assertOrder, entryCount, asserter);

protected async Task AssertQuery<TItem>(
Func<IQueryable<TItem>, IQueryable<object>> efQuery,
Func<IQueryable<TItem>, IQueryable<object>> l2oQuery,
bool assertOrder = false,
int entryCount = 0,
Action<IList<object>, IList<object>> asserter = null)
where TItem : class
{
using (var context = CreateContext())
{
TestHelpers.AssertResults(
l2oQuery(NorthwindData.Set<TItem>()).ToArray(),
await efQuery(context.Set<TItem>()).ToArrayAsync(),
assertOrder,
asserter);

Assert.Equal(entryCount, context.ChangeTracker.Entries().Count());
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2473,6 +2473,48 @@ public virtual void Complex_multi_include_with_order_by_and_paging()
}
}

[ConditionalFact]
public virtual void Multiple_include_with_multiple_optional_navigations()
{
List<Level1> expected;
using (var context = CreateContext())
{
expected = context.LevelOne
.Include(e => e.OneToOne_Required_FK).ThenInclude(e => e.OneToMany_Optional)
.Include(e => e.OneToOne_Required_FK).ThenInclude(e => e.OneToOne_Optional_FK)
.Include(e => e.OneToOne_Optional_FK).ThenInclude(e => e.OneToOne_Optional_FK)
.ToList()
.Where(e => e.OneToOne_Required_FK.OneToOne_Optional_PK?.Name != "Foo")
.OrderBy(e => e.Id)
.ToList();
}

ClearLog();

using (var context = CreateContext())
{
var query = context.LevelOne
.Include(e => e.OneToOne_Required_FK).ThenInclude(e => e.OneToMany_Optional)
.Include(e => e.OneToOne_Required_FK).ThenInclude(e => e.OneToOne_Optional_FK)
.Include(e => e.OneToOne_Optional_FK).ThenInclude(e => e.OneToOne_Optional_FK)
.Where(e => e.OneToOne_Required_FK.OneToOne_Optional_PK.Name != "Foo")
.OrderBy(e => e.Id);

var result = query.ToList();
Assert.Equal(expected.Count, result.Count);
for (var i = 0; i < result.Count; i++)
{
Assert.True(expected[i].Id == result[i].Id);
Assert.True(expected[i].Name == result[i].Name);
Assert.True(expected[i].OneToOne_Required_FK?.Id == result[i].OneToOne_Required_FK?.Id);
Assert.True(expected[i].OneToOne_Required_FK?.OneToOne_Optional_FK?.Id == result[i].OneToOne_Required_FK?.OneToOne_Optional_FK?.Id);

Assert.True(expected[i].OneToOne_Optional_FK?.Id == result[i].OneToOne_Optional_FK?.Id);
Assert.True(expected[i].OneToOne_Optional_FK?.OneToOne_Optional_FK?.Id == result[i].OneToOne_Optional_FK?.OneToOne_Optional_FK?.Id);
}
}
}

[ConditionalFact]
public virtual void Correlated_subquery_doesnt_project_unnecessary_columns_in_top_level()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
<ItemGroup>
<Compile Include="AsNoTrackingTestBase.cs" />
<Compile Include="AsTrackingTestBase.cs" />
<Compile Include="AsyncQueryNavigationsTestBase.cs" />
<Compile Include="AsyncQueryTestBase.cs" />
<Compile Include="BuiltInDataTypesFixtureBase.cs" />
<Compile Include="BuiltInDataTypesTestBase.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,16 @@ public virtual void Select_Where_Navigation_Included()
entryCount: 15);
}

[ConditionalFact]
public virtual void Include_with_multiple_optional_navigations()
{
AssertQuery<OrderDetail>(
ods => ods
.Include(od => od.Order.Customer)
.Where(od => od.Order.Customer.City == "London"),
entryCount: 164);
}

[ConditionalFact]
public virtual void Select_count_plus_sum()
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.Specification.Tests;
using Xunit.Abstractions;

namespace Microsoft.EntityFrameworkCore.SqlServer.FunctionalTests
{
public class AsyncQueryNavigationsSqlServerTests : AsyncQueryNavigationsTestBase<NorthwindQuerySqlServerFixture>
{
public AsyncQueryNavigationsSqlServerTests(NorthwindQuerySqlServerFixture fixture, ITestOutputHelper testOutputHelper)
: base(fixture)
{
// TestSqlLoggerFactory.CaptureOutput(testOutputHelper);
}
}
}
Loading

0 comments on commit c7ff76d

Please sign in to comment.